diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index 084cfe17a4d..1500fceebb2 100644 --- a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -932,6 +932,7 @@ jobs: # Custom operator tests PYTHON_EXECUTABLE=python bash backends/vulkan/test/custom_ops/build_and_run.sh add ./cmake-out/backends/vulkan/test/custom_ops/q8csw_linear + ./cmake-out/backends/vulkan/test/custom_ops/q8csw_conv2d nxp-build-test: name: nxp-build-test diff --git a/backends/vulkan/runtime/graph/ops/glsl/col2im.glsl b/backends/vulkan/runtime/graph/ops/glsl/col2im.glsl new file mode 100644 index 00000000000..c105ef18719 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/col2im.glsl @@ -0,0 +1,94 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_load_type(DTYPE, OUTPUT_STORAGE)} +#define T ${texel_load_component_type(DTYPE, OUTPUT_STORAGE)} + +$if OUTPUT_STORAGE == "buffer": + #define OUTPUT_BUFFER +$if INPUT_STORAGE == "buffer": + #define INPUT_BUFFER + +#define TILE_M4 1 +#define TILE_N4 1 +#define TILE_K4 1 + +#define TILE_M 4 +#define TILE_N 4 +#define TILE_K 4 + +${define_required_extensions(DTYPE)} + +layout(std430) buffer; + +#include "conv2d_common.glslh" + +${layout_declare_tensor(B, "w", "t_output", DTYPE, OUTPUT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_input", DTYPE, INPUT_STORAGE, is_scalar_array=False)} + +// Sizes of the convolution output image +${layout_declare_ubo(B, "ivec4", "output_sizes")} +// Sizes of the convolution input image +${layout_declare_ubo(B, "ivec4", "input_sizes")} +// Sizes of the im2col matrix of the convolution output +${layout_declare_ubo(B, "ivec4", "matrix_sizes")} + +${layout_declare_ubo(B, "Conv2DParams", "conv2d_params")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +#include "conv2d_fp_im2col_block_store.glslh" + +#ifdef INPUT_BUFFER + +void load_matrix_tile( + out FPOutTile tile, + const int n4, + const int m_start, + const int N4) { + [[unroll]] for (int m = 0; m < TILE_M; m++) { + tile.data[m][0] = t_input[(m_start + m) * N4 + n4]; + } +} + +#else // INPUT_TEXTURE + +void load_matrix_tile( + out FPOutTile tile, + const int n4, + const int m_start, + const int N4) { + [[unroll]] for (int m = 0; m < TILE_M; m++) { + tile.data[m][0] = texelFetch( + t_input, ivec3(n4, m_start + m, 0), 0); + } +} + +#endif // INPUT_BUFFER + +void main() { + // Each thread loads and writes a 4 wide x 4 high block of the matrix + const int n4 = int(gl_GlobalInvocationID.x); + const int m4 = int(gl_GlobalInvocationID.y); + + const int n = mul_4(n4); + const int m = mul_4(m4); + + if (n >= matrix_sizes.x || m >= matrix_sizes.y) { + return; + } + + FPOutTile tile; + + const int N4 = div_4(matrix_sizes.x); + load_matrix_tile(tile, n4, m, N4); + write_im2col_tile_as_image(tile, n4, m); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/col2im.yaml b/backends/vulkan/runtime/graph/ops/glsl/col2im.yaml new file mode 100644 index 00000000000..b6d0972271a --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/col2im.yaml @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +col2im: + parameter_names_with_default_values: + DTYPE: float + OUTPUT_STORAGE: texture3d + INPUT_STORAGE: buffer + generate_variant_forall: + DTYPE: + - VALUE: half + - VALUE: float + shader_variants: + - NAME: col2im_texture3d_buffer + - NAME: col2im_texture3d_texture3d + INPUT_STORAGE: texture3d diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_common.glslh b/backends/vulkan/runtime/graph/ops/glsl/conv2d_common.glslh new file mode 100644 index 00000000000..41825cba867 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_common.glslh @@ -0,0 +1,51 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifndef CONV2D_COMMON_GLSLH +#define CONV2D_COMMON_GLSLH + +#include "common.glslh" + +struct Conv2DParams { + ivec2 kernel_size; + ivec2 stride; + ivec2 padding; + ivec2 dilation; + int groups; + int out_channels_per_group; + int in_channels_per_group; + int logical_K_per_group; + int K_per_group; + int K4_per_group; + int logical_K; + int K; + int K4; +}; + +#ifdef DEBUG_MODE + +void printConv2DParams(const Conv2DParams params) { + debugPrintfEXT("Conv2DParams: \\n"); + debugPrintfEXT( + " kernel_size: %d, %d\\n", params.kernel_size.x, params.kernel_size.y); + debugPrintfEXT(" stride: %d, %d\\n", params.stride.x, params.stride.y); + debugPrintfEXT(" padding: %d, %d\\n", params.padding.x, params.padding.y); + debugPrintfEXT(" dilation: %d, %d\\n", params.dilation.x, params.dilation.y); + debugPrintfEXT(" groups: %d\\n", params.groups); + debugPrintfEXT( + " out_channels_per_group: %d\\n", params.out_channels_per_group); + debugPrintfEXT( + " in_channels_per_group: %d\\n", params.in_channels_per_group); + debugPrintfEXT(" logical_K_per_group: %d\\n", params.logical_K_per_group); + debugPrintfEXT(" K_per_group: %d\\n", params.K_per_group); + debugPrintfEXT(" K4_per_group: %d\\n", params.K4_per_group); +} + +#endif // DEBUG_MODE + +#endif // CONV2D_COMMON_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_fp_im2col_block.glslh b/backends/vulkan/runtime/graph/ops/glsl/conv2d_fp_im2col_block.glslh new file mode 100644 index 00000000000..7add8c4cd16 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_fp_im2col_block.glslh @@ -0,0 +1,96 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifndef CONV2D_FP_IM2COL_BLOCK +#define CONV2D_FP_IM2COL_BLOCK + +/* + * Defines utilities to convert between (col, row) indices of an im2col matrix + * and 4-dimension tensor indices of image tensors. + * + * Requires: + * - output_sizes to be defined in the shader layout, corresponding to the sizes + * of the output image of the convolution op. + * - image_sizes to be defined in the shader layout, corresponding to the sizes + * of the input image of the convolution op. + * - conv2d_params to be defined in the shader layout + */ + +#extension GL_EXT_control_flow_attributes : require + +#include "common.glslh" +#include "conv2d_common.glslh" + +struct Im2ColMatrixIdx { + int row; + int col; + // Relevant for grouped convolution. This indicates the column index relative + // to the first column in the group. 
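+  // For example (hypothetical values, for illustration): with K_per_group = 28,
+  // matrix column 30 falls in group_idx = 30 / 28 = 1 and has
+  // col_idx_in_group = 30 % 28 = 2.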
+ int col_idx_in_group; + int group_idx; +}; + +void unwrap_m(out TensorIndex4D out_tidx_base, const int m) { + out_tidx_base.data[3] = m / (output_sizes.y * output_sizes.x); + out_tidx_base.data[1] = (m / output_sizes.x) % output_sizes.y; + out_tidx_base.data[0] = m % output_sizes.x; + + // Initialize channels to 0; assume it will be set later on + out_tidx_base.data[2] = 0; +} + +void im2col_tidx_to_output_tidx( + out TensorIndex4D output_tidx, + const Im2ColMatrixIdx im2col_tidx) { + unwrap_m(output_tidx, im2col_tidx.row); + // Set channels + output_tidx.data.z = im2col_tidx.col; +} + +/* + * Converts im2col matrix position to corresponding 4D tensor index, accounting + * for grouped convolutions. The conversion should ensure that all data within + * the same group occupy a contiguous block in memory. + */ +void im2col_idx_to_input_tidx( + out TensorIndex4D input_tidx, + const Im2ColMatrixIdx im2col_idx) { + TensorIndex4D output_tidx; + unwrap_m(output_tidx, im2col_idx.row); + + const int in_channels_per_group = conv2d_params.in_channels_per_group; + // Determine the corresponding position within the convolution window based + // on the col index (more specifically, the col index within the group) + const int channel_within_group = + im2col_idx.col_idx_in_group % in_channels_per_group; + const int kernel_x = (im2col_idx.col_idx_in_group / in_channels_per_group) % + conv2d_params.kernel_size.x; + const int kernel_y = im2col_idx.col_idx_in_group / + (in_channels_per_group * conv2d_params.kernel_size.x); + + // Calculate the actual input channel index + const int channel_idx = + im2col_idx.group_idx * conv2d_params.in_channels_per_group + + channel_within_group; + + // Calculate corresponding input coordinates based on output position + // associated with the row index. + const int input_y = int(output_tidx.data.y * conv2d_params.stride.y) - + int(conv2d_params.padding.y) + int(kernel_y * conv2d_params.dilation.y); + const int input_x = int(output_tidx.data.x * conv2d_params.stride.x) - + int(conv2d_params.padding.x) + int(kernel_x * conv2d_params.dilation.x); + + input_tidx.data = ivec4(input_x, input_y, channel_idx, output_tidx.data.w); +} + +// 4x4 block of the im2col matrix +struct FPIm2ColBlock { + VEC4_T data[4]; +}; + +#endif // CONV2D_FP_IM2COL_BLOCK diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_fp_im2col_block_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/conv2d_fp_im2col_block_load.glslh new file mode 100644 index 00000000000..c02b070e17e --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_fp_im2col_block_load.glslh @@ -0,0 +1,169 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifndef CONV2D_FP_IM2COL_BLOCK_LOAD +#define CONV2D_FP_IM2COL_BLOCK_LOAD + +/* + * Defines utilities to load data for a 4x4 im2col matrix block from an + * input image and store the data as a FPInputTile. 
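+ *
+ * For illustration (hypothetical sizes): with a 3x3 kernel, 8 input channels and
+ * groups = 1, each im2col row holds the 8 * 3 * 3 = 72 input values of the
+ * convolution window for one output position, and the block at block index
+ * (k4, m4) covers matrix columns 4*k4 .. 4*k4 + 3 and rows 4*m4 .. 4*m4 + 3.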
+ *
+ * Requires:
+ * - t_input to be defined in the shader layout, representing the texture of the
+ *   source image
+ * - conv2d_params to be defined in the shader layout
+ */
+
+#extension GL_EXT_control_flow_attributes : require
+
+#extension GL_EXT_debug_printf : require
+
+#include "common.glslh"
+#include "conv2d_common.glslh"
+#include "conv2d_fp_im2col_block.glslh"
+#include "linear_fp_input_tile.glslh"
+
+VEC4_T load_input_texel(const TensorIndex4D tidx) {
+  // Assumes batch size is 1 and channels packing
+  return texelFetch(
+      t_input, ivec3(tidx.data.x, tidx.data.y, div_4(tidx.data.z)), 0);
+}
+
+T load_input_texel_element(const TensorIndex4D tidx) {
+  const int channels_texel_idx = div_4(tidx.data.z);
+  const int texel_comp = mod_4(tidx.data.z);
+  // Assumes batch size is 1 and channels packing
+  return texelFetch(
+      t_input,
+      ivec3(tidx.data.x, tidx.data.y, channels_texel_idx),
+      0)[texel_comp];
+}
+
+// k4 -> group of 4 input channels idx
+// m -> flattened batch, output width, output height dim idx
+/*
+ * Fast impl for when the input image's channels per group is a multiple of 4.
+ * In this case, it is guaranteed that a texel loaded from the input can be
+ * stored directly to the output without any additional filtering.
+ */
+void load_im2col_block_fast(
+    out FPIm2ColBlock block,
+    const int k4,
+    const int m4,
+    const int logical_K,
+    const int M) {
+  Im2ColMatrixIdx im2col_idx;
+  im2col_idx.col = mul_4(k4); // k
+  im2col_idx.row = mul_4(m4); // m
+
+  // Due to the assumption that in_channels_per_group % 4 == 0, it is
+  // guaranteed that the next 4 columns (including this one) are part of the
+  // same group.
+  im2col_idx.group_idx = im2col_idx.col / conv2d_params.K_per_group;
+  im2col_idx.col_idx_in_group = im2col_idx.col % conv2d_params.K_per_group;
+
+  [[unroll]] for (int m_off = 0; m_off < 4; ++m_off) {
+    if (im2col_idx.row >= M) {
+      block.data[m_off] = VEC4_T(0);
+      continue;
+    }
+
+    TensorIndex4D input_tidx;
+    im2col_idx_to_input_tidx(input_tidx, im2col_idx);
+
+    // Load the texel
+    block.data[m_off] = load_input_texel(input_tidx);
+
+    im2col_idx.row++;
+  }
+}
+
+/*
+ * If the input image's channels per group is not a multiple of 4, then it is
+ * likely that for some matrix texels, the source data is split between
+ * different texels of the source image. In this case it's better to retrieve
+ * each element individually.
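+ *
+ * Worked example (hypothetical sizes): with in_channels_per_group = 3, the
+ * group-relative columns 0..3 map to channels 0, 1 and 2 at kernel position
+ * x = 0 plus channel 0 at kernel position x = 1, so a single 4-wide matrix
+ * texel gathers data from two different texels of the channels-packed input
+ * image.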
+ */ +void load_im2col_block_slow( + out FPIm2ColBlock block, + const int k4, + const int m4, + const int logical_K, + const int M) { + Im2ColMatrixIdx im2col_idx_base; + im2col_idx_base.col = mul_4(k4); + im2col_idx_base.row = mul_4(m4); + + im2col_idx_base.group_idx = im2col_idx_base.col / conv2d_params.K_per_group; + im2col_idx_base.col_idx_in_group = + im2col_idx_base.col % conv2d_params.K_per_group; + + [[unroll]] for (int m_off = 0; m_off < 4; ++m_off) { + [[unroll]] for (int k_off = 0; k_off < 4; ++k_off) { + Im2ColMatrixIdx im2col_idx = im2col_idx_base; + im2col_idx.row += m_off; + im2col_idx.col_idx_in_group += k_off; + + // bounds checking + if (im2col_idx.col_idx_in_group >= conv2d_params.logical_K_per_group || + im2col_idx.row >= M) { + block.data[m_off][k_off] = T(0); + continue; + } + + TensorIndex4D input_tidx; + im2col_idx_to_input_tidx(input_tidx, im2col_idx); + + block.data[m_off][k_off] = load_input_texel_element(input_tidx); + } + } +} + +void load_im2col_block( + out FPIm2ColBlock block, + const int k4, + const int m4, + const int logical_K, + const int M) { + if (mod_4(conv2d_params.in_channels_per_group) == 0) { + load_im2col_block_fast(block, k4, m4, logical_K, M); + } else { + load_im2col_block_slow(block, k4, m4, logical_K, M); + } +} + +void load_input_im2col_tile( + out FPInputTile tile, + const int k4_start, + const int m4_start, + const int logical_K, + const int M) { + FPIm2ColBlock block; +#if TILE_K4 == 1 + [[unroll]] for (int m4 = 0; m4 < TILE_M4; ++m4) { + load_im2col_block(block, k4_start, m4_start + m4, logical_K, M); + for (int row = 0; row < 4; ++row) { + const int m = mul_4(m4) + row; + tile.data[m][0] = block.data[row]; + } + } + +#else + [[unroll]] for (int m4 = 0; m4 < TILE_M4; ++m4) { + [[unroll]] for (int k4 = 0; k4 < TILE_K4; ++k4) { + load_im2col_block(block, k4_start + k4, m4_start + m4, logical_K, M); + for (int row = 0; row < 4; ++row) { + const int m = mul_4(m4) + row; + tile.data[m][k4] = block.data[row]; + } + } + } + +#endif +} + +#endif // CONV2D_FP_IM2COL_BLOCK_LOAD diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_fp_im2col_block_store.glslh b/backends/vulkan/runtime/graph/ops/glsl/conv2d_fp_im2col_block_store.glslh new file mode 100644 index 00000000000..2171d75c628 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_fp_im2col_block_store.glslh @@ -0,0 +1,68 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifndef CONV2D_FP_IM2COL_BLOCK_STORE +#define CONV2D_FP_IM2COL_BLOCK_STORE + +/* + * Defines utilities to store data for a 4x4 im2col output matrix block computed + * from matrix multiplication to an output image. 
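+ *
+ * The mapping implemented by im2col_tidx_to_output_tidx (in
+ * conv2d_fp_im2col_block.glslh) sends matrix row m to the output position
+ *   batch = m / (H_out * W_out), y = (m / W_out) % H_out, x = m % W_out,
+ * and matrix column n to output channel n.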
+ * + * Requires: + * - t_output to be defined in the shader layout, representing the texture of + * the output image + */ + +#extension GL_EXT_control_flow_attributes : require + +#include "common.glslh" +#include "conv2d_common.glslh" +#include "conv2d_fp_im2col_block.glslh" +#include "linear_fp_output_tile.glslh" + +// TODO: implement buffer support +void write_output_texel(const VEC4_T out_texel, const TensorIndex4D tidx) { + // Assume batch size is 1 + imageStore( + t_output, ivec3(tidx.data.x, tidx.data.y, div_4(tidx.data.z)), out_texel); +} + +void write_im2col_tile_as_image( + const FPOutTile tile, + const int n4_start, + const int m_start) { + Im2ColMatrixIdx im2col_tidx; + im2col_tidx.col = mul_4(n4_start); + im2col_tidx.row = m_start; +#if TILE_K4 == 1 + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + TensorIndex4D output_tidx; + im2col_tidx_to_output_tidx(output_tidx, im2col_tidx); + + if (any(greaterThanEqual(output_tidx.data, output_sizes))) { + continue; + } + write_output_texel(tile.data[m][0], output_tidx); + im2col_tidx.row++; + } + +#else + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + [[unroll]] for (int k4 = 0; k4 < TILE_K4; ++k4) { + TensorIndex4D output_tidx; + im2col_tidx_to_output_tidx(output_tidx, im2col_tidx); + + write_output_texel(tile.data[m][k4], output_tidx); + im2col_tidx.row++; + } + } + +#endif +} + +#endif // CONV2D_FP_IM2COL_BLOCK_STORE diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8csw_linear_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8csw_linear_tiled.glsl new file mode 100644 index 00000000000..e2b239800a8 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8csw_linear_tiled.glsl @@ -0,0 +1,120 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#version 450 core + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_load_type(DTYPE, OUTPUT_STORAGE)} +#define T ${texel_load_component_type(DTYPE, OUTPUT_STORAGE)} + +$if OUTPUT_STORAGE == "buffer": + #define OUTPUT_BUFFER +$if INPUT_STORAGE == "buffer": + #define INPUT_BUFFER +$if WEIGHT_STORAGE == "buffer": + #define WEIGHT_BUFFER + +#define TILE_M4 ${TILE_M4} +#define TILE_K4 ${TILE_K4} +#define TILE_N4 ${TILE_N4} + +#define TILE_M ${TILE_M4 * 4} +#define TILE_K ${TILE_K4 * 4} +#define TILE_N ${TILE_N4 * 4} + +${define_required_extensions(DTYPE)} + +layout(std430) buffer; + +#include "conv2d_common.glslh" + +${layout_declare_tensor(B, "w", "t_output", DTYPE, OUTPUT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_input", DTYPE, INPUT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_packed_int8_weight", "int", WEIGHT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_weight_scales", DTYPE, "buffer", is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_bias", DTYPE, "buffer", is_scalar_array=False)} + +${layout_declare_ubo(B, "ivec4", "output_sizes")} +${layout_declare_ubo(B, "ivec4", "input_sizes")} +${layout_declare_ubo(B, "Conv2DParams", "conv2d_params")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +${layout_declare_spec_const(C, "int", "apply_bias", "1")} + +#include "linear_fp_input_tile_load.glslh" +#include "linear_int8_weight_tile_load.glslh" +#include "linear_fp_weight_scales_load.glslh" +#include "linear_fp_bias_load.glslh" +#include "linear_fp_output_tile_fp_int8_compute.glslh" +#include "linear_fp_output_tile_fp_compute.glslh" +#include "conv2d_fp_im2col_block_store.glslh" + +void main() { + // Each thread writes out a 4 wide x 4 high tile of output values + const int out_tile_x = int(gl_GlobalInvocationID.x); + const int out_tile_y = int(gl_GlobalInvocationID.y); + + const int n = int(out_tile_x * TILE_N); + const int m = int(out_tile_y * TILE_M); + + const int n4 = div_4(n); + const int m4 = div_4(m); + + // M = flattened output width, height, batches dims + const int M = output_sizes.x * output_sizes.y * output_sizes.w; + // N = output channels + const int N = output_sizes.z; + + if (n >= N || m >= M) { + return; + } + + const int group_idx = n / conv2d_params.out_channels_per_group; + const int input_k4_offset = conv2d_params.K4_per_group * group_idx; + + const int K4 = conv2d_params.K4; + const int N4 = div_up_4(N); + + FPOutTile out_tile; + initialize(out_tile); + + FPInputTile in_tile; + Int8WeightTile int8_weight_tile; + + const bool dont_check_bounds = (M - m) >= TILE_M; + + if (dont_check_bounds) { + for (int k4 = 0; k4 < conv2d_params.K4_per_group; k4++) { + load_input_tile_no_checks(in_tile, k4 + input_k4_offset, m, K4, M); + load_int8_weight_tile(int8_weight_tile, n4, k4, N4); + fp_accumulate_with_int8_weight(out_tile, in_tile, int8_weight_tile); + } + } else { + for (int k4 = 0; k4 < conv2d_params.K4_per_group; k4++) { + load_input_tile_with_checks(in_tile, k4 + input_k4_offset, m, K4, M); + load_int8_weight_tile(int8_weight_tile, n4, k4, N4); + fp_accumulate_with_int8_weight(out_tile, in_tile, int8_weight_tile); + } + } + + FPPerOutChannelParams weight_scales_tile; + load_weight_scales_tile(weight_scales_tile, n4); + + if (apply_bias > 0) { + FPPerOutChannelParams bias_tile; + load_bias_tile(bias_tile, n4); + + apply_scales_and_biases(out_tile, weight_scales_tile, bias_tile); + } + else { + apply_scales(out_tile, weight_scales_tile); + } + + 
write_im2col_tile_as_image(out_tile, n4, m); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8csw_linear_tiled.yaml b/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8csw_linear_tiled.yaml new file mode 100644 index 00000000000..9b3b5aa2c0a --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8csw_linear_tiled.yaml @@ -0,0 +1,22 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +conv2d_q8csw_linear_tiled: + parameter_names_with_default_values: + DTYPE: float + OUTPUT_STORAGE: texture3d + INPUT_STORAGE: buffer + WEIGHT_STORAGE: texture2d + TILE_M4: 1 + TILE_N4: 1 + TILE_K4: 1 + generate_variant_forall: + DTYPE: + - VALUE: float + shader_variants: + - NAME: conv2d_q8csw_linear_tiled_texture3d_buffer_texture2d + - NAME: conv2d_q8csw_linear_tiled_texture3d_buffer_buffer + WEIGHT_STORAGE: buffer diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8ta_q8csw_linear_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8ta_q8csw_linear_tiled.glsl new file mode 100644 index 00000000000..f74a1311095 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8ta_q8csw_linear_tiled.glsl @@ -0,0 +1,137 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_load_type(DTYPE, OUTPUT_STORAGE)} +#define T ${texel_load_component_type(DTYPE, OUTPUT_STORAGE)} + +$if OUTPUT_STORAGE == "buffer": + #define OUTPUT_BUFFER +$if PACKED_INT8_INPUT_STORAGE == "buffer": + #define PACKED_INT8_INPUT_BUFFER +$if WEIGHT_STORAGE == "buffer": + #define WEIGHT_BUFFER + +#define TILE_M4 ${TILE_M4} +#define TILE_K4 ${TILE_K4} +#define TILE_N4 ${TILE_N4} + +#define TILE_M ${TILE_M4 * 4} +#define TILE_K ${TILE_K4 * 4} +#define TILE_N ${TILE_N4 * 4} + +${define_required_extensions(DTYPE)} + +#extension GL_EXT_integer_dot_product : require + +layout(std430) buffer; + +#include "conv2d_common.glslh" + +${layout_declare_tensor(B, "w", "t_output", DTYPE, OUTPUT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_packed_int8_input", "int", PACKED_INT8_INPUT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_packed_int8_weight", "int", WEIGHT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_weight_sums", "int", "buffer", is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_weight_scales", DTYPE, "buffer", is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_bias", DTYPE, "buffer", is_scalar_array=False)} + +${layout_declare_ubo(B, "ivec4", "output_sizes")} +${layout_declare_ubo(B, "ivec4", "input_sizes")} +${layout_declare_ubo(B, "Conv2DParams", "conv2d_params")} + +layout(push_constant) uniform restrict Block { + float input_scale; + int input_zp; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +${layout_declare_spec_const(C, "int", "apply_bias", "1")} + +#include "linear_int8_input_tile_load.glslh" +#include "linear_int8_weight_tile_load.glslh" +#include "linear_int_weight_sums_load.glslh" +#include "linear_fp_weight_scales_load.glslh" +#include "linear_fp_output_tile_int8_int8_compute.glslh" +#include "linear_fp_bias_load.glslh" +#include 
"conv2d_fp_im2col_block_store.glslh" + +void main() { + // Each thread writes out a 4 wide x 4 high tile of output values + const int out_tile_x = int(gl_GlobalInvocationID.x); + const int out_tile_y = int(gl_GlobalInvocationID.y); + + const int n = int(out_tile_x * TILE_N); + const int m = int(out_tile_y * TILE_M); + + const int n4 = div_4(n); + const int m4 = div_4(m); + + // M = flattened output width, height, batches dims + const int M = output_sizes.x * output_sizes.y * output_sizes.w; + // N = output channels + const int N = output_sizes.z; + + if (n >= N || m >= M) { + return; + } + + const int group_idx = n / conv2d_params.out_channels_per_group; + const int input_k4_offset = conv2d_params.K4_per_group * group_idx; + + const int K4 = conv2d_params.K4; + const int N4 = div_up_4(N); + + Int32Accum out_accum; + initialize(out_accum); + + Int8InputTile int8_in_tile; + Int8WeightTile int8_weight_tile; + + for (int k4 = 0; k4 < conv2d_params.K4_per_group; k4++) { + load_int8_input_tile(int8_in_tile, k4 + input_k4_offset, m4, K4); + load_int8_weight_tile(int8_weight_tile, n4, k4, N4); + + int_accumulate_with_int8_weight(out_accum, int8_in_tile, int8_weight_tile); + } + + FPPerOutChannelParams weight_scales_tile; + load_weight_scales_tile(weight_scales_tile, n4); + + IntPerOutChannelParams weight_sums_tile; + load_weight_sums_tile(weight_sums_tile, n4); + + FPOutTile out_tile; + initialize(out_tile); + if (apply_bias > 0) { + FPPerOutChannelParams bias_tile; + load_bias_tile(bias_tile, int(n4)); + + accumulate_out_tile_with_int_accum( + out_tile, + out_accum, + input_scale, + input_zp, + weight_sums_tile, + weight_scales_tile, + bias_tile); + } + else { + accumulate_out_tile_with_int_accum( + out_tile, + out_accum, + input_scale, + input_zp, + weight_sums_tile, + weight_scales_tile); + } + + write_im2col_tile_as_image(out_tile, n4, m); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8ta_q8csw_linear_tiled.yaml b/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8ta_q8csw_linear_tiled.yaml new file mode 100644 index 00000000000..629001765c1 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8ta_q8csw_linear_tiled.yaml @@ -0,0 +1,22 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +conv2d_q8ta_q8csw_linear_tiled: + parameter_names_with_default_values: + DTYPE: float + OUTPUT_STORAGE: texture3d + PACKED_INT8_INPUT_STORAGE: buffer + WEIGHT_STORAGE: texture2d + TILE_M4: 1 + TILE_N4: 1 + TILE_K4: 1 + generate_variant_forall: + DTYPE: + - VALUE: float + shader_variants: + - NAME: conv2d_q8ta_q8csw_linear_tiled_texture3d_buffer_texture2d + - NAME: conv2d_q8ta_q8csw_linear_tiled_texture3d_buffer_buffer + WEIGHT_STORAGE: buffer diff --git a/backends/vulkan/runtime/graph/ops/glsl/im2col.glsl b/backends/vulkan/runtime/graph/ops/glsl/im2col.glsl new file mode 100644 index 00000000000..f045d4e9702 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/im2col.glsl @@ -0,0 +1,110 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#version 450 core + +#extension GL_EXT_debug_printf : enable + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_load_type(DTYPE, INPUT_STORAGE)} +#define T ${texel_load_component_type(DTYPE, INPUT_STORAGE)} + +$if OUTPUT_STORAGE == "buffer": + #define OUTPUT_BUFFER +$if INPUT_STORAGE == "buffer": + #define INPUT_BUFFER + +#define TILE_M4 1 +#define TILE_N4 1 +#define TILE_K4 1 + +#define TILE_M 4 +#define TILE_N 4 +#define TILE_K 4 + +${define_required_extensions(DTYPE)} + +layout(std430) buffer; + +#include "conv2d_common.glslh" + +${layout_declare_tensor(B, "w", "t_output", DTYPE, OUTPUT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_input", DTYPE, INPUT_STORAGE, is_scalar_array=False)} + +// Sizes of the im2col matrix of the convolution input +${layout_declare_ubo(B, "ivec4", "matrix_sizes")} +// Sizes of the input image +${layout_declare_ubo(B, "ivec4", "input_sizes")} +// Sizes of the output image +${layout_declare_ubo(B, "ivec4", "output_sizes")} + +${layout_declare_ubo(B, "Conv2DParams", "conv2d_params")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +#include "conv2d_fp_im2col_block_load.glslh" + +#ifdef OUTPUT_BUFFER + +void write_tile( + const FPInputTile in_tile, + const int k4, + const int m_start, + const int K4) { + [[unroll]] for (int m = 0; m < TILE_M; m++) { + t_output[(m_start + m) * K4 + k4] = in_tile.data[m][0]; + } +} + +#else // OUTPUT_TEXTURE + +void write_tile( + const FPInputTile in_tile, + const int k4, + const int m_start, + const int K4) { + [[unroll]] for (int m = 0; m < TILE_M; m++) { + imageStore(t_output, ivec3(k4, m_start + m, 0), vec4(in_tile.data[m][0])); + } +} + +#endif // OUTPUT_BUFFER + +void main() { + // Each thread writes out a 4 wide x 4 high block of the output matrix. The + // thread position corresponds to the block index. + const int k4 = int(gl_GlobalInvocationID.x); + const int m4 = int(gl_GlobalInvocationID.y); + + // Convert block idx to tensor idx + const int k = mul_4(k4); + const int m = mul_4(m4); + + const int in_channels_per_group = input_sizes.z / conv2d_params.groups; + + // Logical K dim size (unpadded) + const int logical_K = conv2d_params.logical_K; + // Physical K dim, which contains padding elements + const int K = matrix_sizes.x; + + // M dim, which represents the number of flattened output width, height, + // batches. Unlike K, there is no difference between the physical and logical + // sizes. + const int M = matrix_sizes.y; + + if (k >= K || m >= M) { + return; + } + + FPInputTile in_tile; + load_input_im2col_tile(in_tile, k4, m4, logical_K, M); + + // Number of texels in the x dim of the output matrix + const int K4 = div_4(K); + write_tile(in_tile, k4, m, K4); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/im2col.yaml b/backends/vulkan/runtime/graph/ops/glsl/im2col.yaml new file mode 100644 index 00000000000..dd486b0e1a6 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/im2col.yaml @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +im2col: + parameter_names_with_default_values: + DTYPE: float + OUTPUT_STORAGE: buffer + INPUT_STORAGE: texture3d + generate_variant_forall: + DTYPE: + - VALUE: float + shader_variants: + - NAME: im2col_buffer_texture3d + - NAME: im2col_texture3d_texture3d + OUTPUT_STORAGE: texture3d diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_im2col.glsl b/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_im2col.glsl new file mode 100644 index 00000000000..450d6376537 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_im2col.glsl @@ -0,0 +1,89 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_load_type(DTYPE, INPUT_STORAGE)} +#define T ${texel_load_component_type(DTYPE, INPUT_STORAGE)} + +$if OUTPUT_STORAGE == "buffer": + #define OUTPUT_BUFFER +$if INPUT_STORAGE == "buffer": + #define INPUT_BUFFER + +#define TILE_M4 1 +#define TILE_N4 1 +#define TILE_K4 1 + +#define TILE_M 4 +#define TILE_N 4 +#define TILE_K 4 + +${define_required_extensions(DTYPE)} + +layout(std430) buffer; + +#include "conv2d_common.glslh" + +${layout_declare_tensor(B, "w", "t_packed_int8_input", "int", OUTPUT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_input", DTYPE, INPUT_STORAGE, is_scalar_array=False)} + +// Sizes of the im2col matrix of the convolution input +${layout_declare_ubo(B, "ivec4", "matrix_sizes")} +// Sizes of the input image +${layout_declare_ubo(B, "ivec4", "input_sizes")} +// Sizes of the output image +${layout_declare_ubo(B, "ivec4", "output_sizes")} + +${layout_declare_ubo(B, "Conv2DParams", "conv2d_params")} + +layout(push_constant) uniform restrict Block { + float inv_scale; + int zp; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +#include "conv2d_fp_im2col_block_load.glslh" +#include "linear_int8_input_block.glslh" + +void main() { + // The quantized and packed im2col matrix can be conceptualized as a 2D matrix + // with K/4 columns and M/4 rows. Each element of the matrix is a ivec4 which + // contains packed data for a 4 wide x 4 high block of the original im2col + // matrix. Each shader invocation works on writing out one ivec4, i.e. one + // block of the quantized and packed matrix. + + // Thread id corresponds to the block index + const int k4 = int(gl_GlobalInvocationID.x); + const int m4 = int(gl_GlobalInvocationID.y); + + // Convert block idx to tensor idx + const int k = mul_4(k4); + const int m = mul_4(m4); + + const int logical_K = conv2d_params.logical_K; + // Similarly, compute the logical size of the M dim. 
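+  // logical_M counts flattened output positions (out_width * out_height *
+  // batches) without the padding to a multiple of 4 used by the packed matrix;
+  // it is used below only for bounds checking.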
+ const int logical_M = output_sizes.x * output_sizes.y * output_sizes.w; + + // Check if tensor indices are out of bounds + if (k >= logical_K || m >= logical_M) { + return; + } + + FPInputTile in_tile; + load_input_im2col_tile(in_tile, k4, m4, logical_K, logical_M); + + Int8InputBlock packed_block; + quantize_and_pack(packed_block, in_tile, inv_scale, zp); + + // Number of texels in the x dim of the output matrix + const int K4 = div_4(matrix_sizes.x); + write_block(packed_block, k4, m4, K4); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_im2col.yaml b/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_im2col.yaml new file mode 100644 index 00000000000..93f8269d607 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_im2col.yaml @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +quantize_and_pack_im2col: + parameter_names_with_default_values: + DTYPE: float + OUTPUT_STORAGE: buffer + INPUT_STORAGE: texture3d + generate_variant_forall: + DTYPE: + - VALUE: float + shader_variants: + - NAME: quantize_and_pack_im2col_buffer_texture3d + - NAME: quantize_and_pack_im2col_texture3d_texture3d + OUTPUT_STORAGE: texture3d diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedConvolution.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedConvolution.cpp new file mode 100644 index 00000000000..51f8138485e --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedConvolution.cpp @@ -0,0 +1,695 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include +#include +#include + +namespace vkcompute { + +// +// Utility functions +// + +struct Conv2DParams { + utils::ivec2 kernel_size; + utils::ivec2 stride; + utils::ivec2 padding; + utils::ivec2 dilation; + int32_t groups; + int32_t out_channels_per_group; + int32_t in_channels_per_group; + int32_t logical_K_per_group; + int32_t K_per_group; + int32_t K4_per_group; + int32_t logical_K; + int32_t K; + int32_t K4; +}; + +Conv2DParams create_conv2d_params( + ComputeGraph& graph, + const ValueRef& conv_input, + const ValueRef& conv_output, + const ValueRef& kernel_size, + const ValueRef& stride, + const ValueRef& padding, + const ValueRef& dilation, + const ValueRef& groups) { + const auto kernel_size_list = graph.get_int_list(kernel_size); + const auto stride_list = graph.get_int_list(stride); + const auto padding_list = graph.get_int_list(padding); + const auto dilation_list = graph.get_int_list(dilation); + const int32_t groups_val = graph.get_int(groups); + + // Pre-compute input and output channels per group + + std::vector out_sizes = graph.sizes_of(conv_output); + const int32_t out_channels = utils::val_at(-3, out_sizes); + const int32_t out_channels_per_group = out_channels / groups_val; + + std::vector in_sizes = graph.sizes_of(conv_input); + const int32_t in_channels = utils::val_at(-3, in_sizes); + const int32_t in_channels_per_group = in_channels / groups_val; + + // Pre-compute the number of elements along the K dimension per group. This + // quantity is aligned to the next multiple of 4 to ensure data loads are + // aligned to texel boundaries. 
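+  // For example (hypothetical sizes): a 3x3 kernel with 3 input channels per
+  // group gives logical_K_per_group = 27, which is padded up to K_per_group = 28
+  // (K4_per_group = 7 texels).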
+ + const int32_t logical_K_per_group = + kernel_size_list->at(0) * kernel_size_list->at(1) * in_channels_per_group; + const int32_t K_per_group = utils::align_up_4(logical_K_per_group); + const int32_t K4_per_group = K_per_group / 4; + + // Pre-compute the "theoretical" size of the K dim of the input im2col matrix, + // which represents the flattened convolution window used to compute an output + // element. This is used for bounds checking. + + const int32_t logical_K = + kernel_size_list->at(0) * kernel_size_list->at(1) * in_channels; + + const int32_t K = K_per_group * groups_val; + // Used for texel stride calculations + const int32_t K4 = K / 4; + + return Conv2DParams{ + // Swap the order from HW to WH + utils::make_ivec2({kernel_size_list->at(1), kernel_size_list->at(0)}), + utils::make_ivec2({stride_list->at(1), stride_list->at(0)}), + utils::make_ivec2({padding_list->at(1), padding_list->at(0)}), + utils::make_ivec2({dilation_list->at(1), dilation_list->at(0)}), + groups_val, + out_channels_per_group, + in_channels_per_group, + logical_K_per_group, + K_per_group, + K4_per_group, + logical_K, + K, + K4, + }; +} + +std::vector calculate_input_im2col_sizes( + ComputeGraph* graph, + const ValueRef& input, + const ValueRef& output, + const ValueRef& kernel_size, + const ValueRef& groups) { + std::vector in_sizes = graph->sizes_of(input); + const int64_t in_channels = utils::val_at(-3, in_sizes); + + std::vector out_sizes = graph->sizes_of(output); + const int64_t batches = utils::val_at(-4, out_sizes); + const int64_t out_height = utils::val_at(-2, out_sizes); + const int64_t out_width = utils::val_at(-1, out_sizes); + + // Represents the number of channel groups + const int64_t groups_val = graph->extract_scalar(groups); + // No need to div_up because in_channels % groups_val = 0 + const int64_t in_channels_per_group = in_channels / groups_val; + + const auto kernel_size_list = graph->get_int_list(kernel_size); + + // Align to the next multiple of 4 to ensure that data loads align nicely with + // texel boundaries. We want to ensure that the first data element of each + // group is at the start of its texel. + const int64_t flattened_kernel_len = utils::align_up_4( + in_channels_per_group * kernel_size_list->at(0) * + kernel_size_list->at(1)); + + // K -> flattened convolution window (adjusted) + const int64_t K = flattened_kernel_len * groups_val; + // M -> number of elements in 2D output plane. This is aligned to the next + // multiple of 4 since the im2col shader operates on 4x4 blocks. 
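+  // For example (hypothetical sizes): a 7x7 output plane with a batch size of 1
+  // has 49 output positions, which is padded up to M = 52.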
+ const int64_t M = utils::align_up_4(out_height * out_width * batches); + + return {M, K}; +} + +std::vector calculate_output_im2col_sizes( + ComputeGraph* graph, + const ValueRef& output) { + std::vector out_sizes = graph->sizes_of(output); + const int64_t batches = utils::val_at(-4, out_sizes); + const int64_t out_channels = utils::val_at(-3, out_sizes); + const int64_t out_height = utils::val_at(-2, out_sizes); + const int64_t out_width = utils::val_at(-1, out_sizes); + + // N -> output channels + const int64_t N = out_channels; + // M -> number of elements in 2D output plane + const int64_t M = out_height * out_width * batches; + + return {M, N}; +} + +// +// Shader dispatch utilities +// + +utils::uvec3 im2col_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + const ValueRef input_image = args.at(1).refs.at(0); + const ValueRef output_image = resize_args.at(0); + const ValueRef kernel_size = resize_args.at(1); + const ValueRef groups = resize_args.at(2); + + std::vector im2col_sizes = calculate_input_im2col_sizes( + graph, input_image, output_image, kernel_size, groups); + const uint32_t K = utils::safe_downcast(im2col_sizes[1]); + const uint32_t M = utils::safe_downcast(im2col_sizes[0]); + + // 1 output tile is 4x4 elements + const uint32_t K4 = utils::div_up(K, 4u); + const uint32_t M4 = utils::div_up(M, 4u); + + return {K4, M4, 1}; +} + +utils::uvec3 col2im_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + const ValueRef output = args.at(0).refs.at(0); + + std::vector im2col_sizes = + calculate_output_im2col_sizes(graph, output); + const uint32_t N = utils::safe_downcast(im2col_sizes[1]); + const uint32_t M = utils::safe_downcast(im2col_sizes[0]); + + // 1 output tile is 4x4 elements + const uint32_t N4 = utils::div_up(N, 4u); + const uint32_t M4 = utils::div_up(M, 4u); + + return {N4, M4, 1}; +} + +// +// Dispatch nodes +// + +void add_input_im2col_node( + ComputeGraph& graph, + const ValueRef input_image, + const ValueRef kernel_size, + const ValueRef stride, + const ValueRef padding, + const ValueRef dilation, + const ValueRef groups, + const ValueRef output_image, + const ValueRef input_im2col) { + Conv2DParams conv_params = create_conv2d_params( + graph, + input_image, + output_image, + kernel_size, + stride, + padding, + dilation, + groups); + + std::string kernel_name = "im2col"; + add_storage_type_suffix(kernel_name, graph.storage_type_of(input_im2col)); + add_storage_type_suffix(kernel_name, graph.storage_type_of(input_image)); + add_dtype_suffix(kernel_name, graph.dtype_of(output_image)); + + vkapi::ParamsBindList param_buffers = { + graph.sizes_ubo(input_im2col), + graph.sizes_ubo(input_image), + graph.sizes_ubo(output_image), + graph.create_params_buffer(conv_params)}; + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + im2col_global_wg_size, + default_pick_local_wg_size, + // Inputs and Outputs + {{input_im2col, vkapi::kWrite}, {input_image, vkapi::kRead}}, + // Shader params buffers + param_buffers, + // Push Constants + {}, + // Specialization Constants + {}, + // Resize args + {output_image, kernel_size, groups}, + // Resizing Logic + nullptr)); +} + +void add_quantize_and_pack_im2col_node( + ComputeGraph& graph, + const ValueRef input_image, + const ValueRef input_scale, + const ValueRef input_zp, + const ValueRef kernel_size, + 
const ValueRef stride, + const ValueRef padding, + const ValueRef dilation, + const ValueRef groups, + const ValueRef output_image, + const ValueRef input_int_im2col) { + Conv2DParams conv_params = create_conv2d_params( + graph, + input_image, + output_image, + kernel_size, + stride, + padding, + dilation, + groups); + + float inv_scale = 1.0f / graph.extract_scalar(input_scale); + int32_t zp = graph.extract_scalar(input_zp); + + // Get shader for quantized conv2d linear tiled + std::string kernel_name = "quantize_and_pack_im2col"; + add_storage_type_suffix(kernel_name, graph.storage_type_of(input_int_im2col)); + add_storage_type_suffix(kernel_name, graph.storage_type_of(input_image)); + add_dtype_suffix(kernel_name, graph.dtype_of(output_image)); + + vkapi::ShaderInfo shader = VK_KERNEL_FROM_STR(kernel_name); + + vkapi::ParamsBindList param_buffers = { + graph.sizes_ubo(input_int_im2col), + graph.sizes_ubo(input_image), + graph.sizes_ubo(output_image), + graph.create_params_buffer(conv_params)}; + + std::vector push_constants = { + PushConstantDataInfo(&inv_scale, sizeof(inv_scale)), + PushConstantDataInfo(&zp, sizeof(zp)), + }; + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + im2col_global_wg_size, + default_pick_local_wg_size, + // Inputs and Outputs + {{input_int_im2col, vkapi::kWrite}, {input_image, vkapi::kRead}}, + // Shader params buffers + param_buffers, + // Push Constants + push_constants, + // Specialization Constants + {}, + // Resize args + {output_image, kernel_size, groups}, + // Resizing Logic + nullptr)); +} + +void add_conv2d_q8csw_linear_node( + ComputeGraph& graph, + const ValueRef input_im2col, + const ValueRef input_image, + const ValueRef packed_weight, + const ValueRef packed_weight_scales, + const ValueRef bias_data, + const ValueRef packed_bias, + const ValueRef kernel_size, + const ValueRef stride, + const ValueRef padding, + const ValueRef dilation, + const ValueRef groups, + const ValueRef output_image) { + Conv2DParams conv_params = create_conv2d_params( + graph, + input_image, + output_image, + kernel_size, + stride, + padding, + dilation, + groups); + + // One limitation of the current implementation is that for grouped convs, + // the number of output_image channels per group must be a multiple of 4. One + // loaded 4x4 weight tile must all belong to the same group. 
+ if (conv_params.groups > 1) { + VK_CHECK_COND(conv_params.out_channels_per_group % 4 == 0); + } + + std::string kernel_name = "conv2d_q8csw_linear_tiled"; + add_storage_type_suffix(kernel_name, graph.storage_type_of(output_image)); + add_storage_type_suffix(kernel_name, graph.storage_type_of(input_im2col)); + add_storage_type_suffix(kernel_name, graph.storage_type_of(packed_weight)); + add_dtype_suffix(kernel_name, graph.dtype_of(output_image)); + vkapi::ShaderInfo shader = VK_KERNEL_FROM_STR(kernel_name); + + vkapi::ParamsBindList param_buffers = { + graph.sizes_ubo(output_image), + graph.sizes_ubo(input_image), + graph.create_params_buffer(conv_params)}; + + uint32_t apply_bias = 1; + if (graph.val_is_none(bias_data)) { + apply_bias = 0; + } + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + col2im_global_wg_size, + quantized_linear_local_wg_size, + // Inputs and Outputs + {{output_image, vkapi::kWrite}, + {{input_im2col, packed_weight, packed_weight_scales, packed_bias}, + vkapi::kRead}}, + // Shader params buffers + param_buffers, + // Push Constants + {}, + // Specialization Constants + {apply_bias}, + // Resize args + {}, + // Resizing Logic + nullptr)); +} + +void add_conv2d_q8ta_q8csw_linear_node( + ComputeGraph& graph, + const ValueRef input_int_im2col, + const ValueRef input_image, + const ValueRef input_scale, + const ValueRef input_zp, + const ValueRef weight_data, + const ValueRef packed_weight, + const ValueRef packed_weight_sums, + const ValueRef packed_weight_scales, + const ValueRef bias_data, + const ValueRef packed_bias, + const ValueRef kernel_size, + const ValueRef stride, + const ValueRef padding, + const ValueRef dilation, + const ValueRef groups, + const ValueRef output_image) { + Conv2DParams conv_params = create_conv2d_params( + graph, + input_image, + output_image, + kernel_size, + stride, + padding, + dilation, + groups); + + // One limitation of the current implementation is that for grouped convs, + // the number of output channels per group must be a multiple of 4. One loaded + // 4x4 weight tile must all belong to the same group. 
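+  // For example (illustrative): 32 output channels with groups = 4 (8 channels
+  // per group) is supported, while groups = 16 (2 channels per group) is not.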
+ if (conv_params.groups > 1) { + VK_CHECK_COND(conv_params.out_channels_per_group % 4 == 0); + } + + float scale = graph.extract_scalar(input_scale); + int32_t zp = graph.extract_scalar(input_zp); + + std::string kernel_name = "conv2d_q8ta_q8csw_linear_tiled"; + add_storage_type_suffix(kernel_name, graph.storage_type_of(output_image)); + add_storage_type_suffix(kernel_name, graph.storage_type_of(input_int_im2col)); + add_storage_type_suffix(kernel_name, graph.storage_type_of(packed_weight)); + add_dtype_suffix(kernel_name, graph.dtype_of(output_image)); + vkapi::ShaderInfo shader = VK_KERNEL_FROM_STR(kernel_name); + + vkapi::ParamsBindList param_buffers = { + graph.sizes_ubo(output_image), + graph.sizes_ubo(input_image), + graph.create_params_buffer(conv_params)}; + + std::vector push_constants = { + PushConstantDataInfo(&scale, sizeof(scale)), + PushConstantDataInfo(&zp, sizeof(zp)), + }; + + uint32_t apply_bias = 1; + if (graph.val_is_none(bias_data)) { + apply_bias = 0; + } + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + col2im_global_wg_size, + quantized_linear_local_wg_size, + // Inputs and Outputs + {{output_image, vkapi::kWrite}, + {{input_int_im2col, + packed_weight, + packed_weight_sums, + packed_weight_scales, + packed_bias}, + vkapi::kRead}}, + // Shader params buffers + param_buffers, + // Push Constants + push_constants, + // Specialization Constants + {apply_bias}, + // Resize args + {weight_data}, + // Resizing Logic + nullptr)); +} + +// +// High level operator impl +// + +void quantized_conv2d_impl( + ComputeGraph& graph, + const QuantizationConfig& input_quant_config, + const QuantizationConfig& weight_quant_config, + const ValueRef input_image, + const ValueRef input_scale, + const ValueRef input_zp, + const ValueRef weight_data, + const ValueRef weight_sums_data, + const ValueRef weight_scales_data, + const ValueRef bias_data, + const ValueRef kernel_size, + const ValueRef stride, + const ValueRef padding, + const ValueRef dilation, + const ValueRef groups, + const ValueRef output_image) { + VK_CHECK_COND(weight_quant_config.granularity == kPerChannel); + VK_CHECK_COND(weight_quant_config.nbits == 8); + VK_CHECK_COND(weight_quant_config.is_symmetric); + + const ValueRef packed_weight = + prepack_quantized_linear_weight(graph, weight_quant_config, weight_data); + ValueRef packed_weight_scales = prepack_standard( + graph, weight_scales_data, utils::kBuffer, utils::kWidthPacked); + + // Create a dummy tensor to fill the binding slot of the bias tensor if it is + // not provided. This helps simplify dispatch logic and makes it so that + // fewer shader variants need to be generated. + TmpTensor dummy_bias( + &graph, + {}, + graph.dtype_of(output_image), + utils::kBuffer, + utils::kWidthPacked); + + ValueRef packed_bias = dummy_bias.vref; + if (!graph.val_is_none(bias_data)) { + packed_bias = + prepack_standard(graph, bias_data, utils::kBuffer, utils::kWidthPacked); + } + + std::vector input_im2col_sizes = calculate_input_im2col_sizes( + &graph, input_image, output_image, kernel_size, groups); + + // Use weight only quantized conv2d if at least one is true: + // 1. Device does not support int8 dot product + // 2. 
Input is not quantized + if (!graph.can_use_int8_dot_product() || + input_quant_config.granularity == kNoQuantization) { + TmpTensor input_im2col( + &graph, + input_im2col_sizes, + vkapi::kFloat, + utils::kBuffer, + utils::kWidthPacked); + + add_input_im2col_node( + graph, + input_image, + kernel_size, + stride, + padding, + dilation, + groups, + output_image, + input_im2col); + + add_conv2d_q8csw_linear_node( + graph, + input_im2col, + input_image, + packed_weight, + packed_weight_scales, + bias_data, + packed_bias, + kernel_size, + stride, + padding, + dilation, + groups, + output_image); + return; + } else { + // Otherwise, use activation + weight quantized conv2d + VK_CHECK_COND(input_quant_config.granularity == kPerTensor); + VK_CHECK_COND(weight_quant_config.nbits == 8); + VK_CHECK_COND(!weight_quant_config.is_dynamic); + + ValueRef packed_weight_sums = prepack_standard( + graph, weight_sums_data, utils::kBuffer, utils::kWidthPacked); + + // Allocate quantized + packed im2col matrix for input + const int64_t num_blocks_M = utils::div_up_4(input_im2col_sizes.at(0)); + const int64_t num_blocks_K = utils::div_up_4(input_im2col_sizes.at(1)); + + TmpTensor input_int_im2col( + &graph, + {num_blocks_M, num_blocks_K * 4}, + vkapi::kInt, + utils::kBuffer, + utils::kWidthPacked); + + add_quantize_and_pack_im2col_node( + graph, + input_image, + input_scale, + input_zp, + kernel_size, + stride, + padding, + dilation, + groups, + output_image, + input_int_im2col); + + add_conv2d_q8ta_q8csw_linear_node( + graph, + input_int_im2col, + input_image, + input_scale, + input_zp, + weight_data, + packed_weight, + packed_weight_sums, + packed_weight_scales, + bias_data, + packed_bias, + kernel_size, + stride, + padding, + dilation, + groups, + output_image); + return; + }; +} + +void conv2d_q8ta_q8csw(ComputeGraph& graph, const std::vector& args) { + int32_t idx = 0; + const ValueRef input_image = args.at(idx++); + const ValueRef input_scale = args.at(idx++); + const ValueRef input_zp = args.at(idx++); + const ValueRef weight_data = args.at(idx++); + const ValueRef weight_sums_data = args.at(idx++); + const ValueRef weight_scales_data = args.at(idx++); + const ValueRef bias_data = args.at(idx++); + const ValueRef kernel_size = args.at(idx++); + const ValueRef stride = args.at(idx++); + const ValueRef padding = args.at(idx++); + const ValueRef dilation = args.at(idx++); + const ValueRef groups = args.at(idx++); + const ValueRef output_image = args.at(idx++); + + const int64_t K = graph.size_at(-1, weight_data); + + QuantizationConfig input_quant_config(8, kPerTensor, {}, false); + QuantizationConfig weight_quant_config(8, kPerChannel, {K}); + + quantized_conv2d_impl( + graph, + input_quant_config, + weight_quant_config, + input_image, + input_scale, + input_zp, + weight_data, + weight_sums_data, + weight_scales_data, + bias_data, + kernel_size, + stride, + padding, + dilation, + groups, + output_image); +} + +void conv2d_q8csw(ComputeGraph& graph, const std::vector& args) { + int32_t idx = 0; + const ValueRef input_image = args.at(idx++); + const ValueRef weight_data = args.at(idx++); + const ValueRef weight_scales_data = args.at(idx++); + const ValueRef bias_data = args.at(idx++); + const ValueRef kernel_size = args.at(idx++); + const ValueRef stride = args.at(idx++); + const ValueRef padding = args.at(idx++); + const ValueRef dilation = args.at(idx++); + const ValueRef groups = args.at(idx++); + const ValueRef output_image = args.at(idx++); + + const int64_t K = graph.size_at(-1, weight_data); + + 
QuantizationConfig input_quant_config(32, kNoQuantization, {}); + QuantizationConfig weight_quant_config(8, kPerChannel, {K}); + + quantized_conv2d_impl( + graph, + input_quant_config, + weight_quant_config, + input_image, + kDummyValueRef, // input scale + kDummyValueRef, // input zero point + weight_data, + kDummyValueRef, // weight sums + weight_scales_data, + bias_data, + kernel_size, + stride, + padding, + dilation, + groups, + output_image); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(et_vk.conv2d_q8ta_q8csw.default, conv2d_q8ta_q8csw); + VK_REGISTER_OP(et_vk.conv2d_q8csw.default, conv2d_q8csw); +} + +} // namespace vkcompute diff --git a/backends/vulkan/test/custom_ops/CMakeLists.txt b/backends/vulkan/test/custom_ops/CMakeLists.txt index 6944fe59385..5ccc83c60e5 100644 --- a/backends/vulkan/test/custom_ops/CMakeLists.txt +++ b/backends/vulkan/test/custom_ops/CMakeLists.txt @@ -92,4 +92,5 @@ if(TARGET vulkan_backend) # Define operator prototypes add_operator_prototype(add) add_operator_prototype(q8csw_linear) + add_operator_prototype(q8csw_conv2d) endif() diff --git a/backends/vulkan/test/custom_ops/q8csw_conv2d.cpp b/backends/vulkan/test/custom_ops/q8csw_conv2d.cpp new file mode 100644 index 00000000000..d566e5b2646 --- /dev/null +++ b/backends/vulkan/test/custom_ops/q8csw_conv2d.cpp @@ -0,0 +1,785 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include +#include +#include "utils.h" + +#include + +using namespace executorch::vulkan::prototyping; + +using namespace vkcompute; + +static constexpr int64_t kRefDimSizeLimit = 100; + +// Component structs for better readability +struct KernelSize { + int32_t h; + int32_t w; + + KernelSize(int32_t height, int32_t width) : h(height), w(width) {} +}; + +struct Stride { + int32_t h; + int32_t w; + + Stride(int32_t height, int32_t width) : h(height), w(width) {} +}; + +struct Padding { + int32_t h; + int32_t w; + + Padding(int32_t height, int32_t width) : h(height), w(width) {} +}; + +struct Dilation { + int32_t h; + int32_t w; + + Dilation(int32_t height = 1, int32_t width = 1) : h(height), w(width) {} +}; + +struct OutInChannels { + int32_t out; + int32_t in; + + OutInChannels(int32_t out_channels, int32_t in_channels) + : out(out_channels), in(in_channels) {} +}; + +struct InputSize2D { + int32_t h; + int32_t w; + + InputSize2D(int32_t height, int32_t width) : h(height), w(width) {} +}; + +// Conv2d configuration struct +struct Conv2dConfig { + OutInChannels channels; + InputSize2D input_size; + KernelSize kernel; + Stride stride; + Padding padding; + Dilation dilation; + int32_t groups; // Number of groups for grouped convolution + std::string test_case_name = "placeholder"; + std::string op_name = "conv2d_q8ta_q8csw"; + + // Calculate output dimensions + int64_t get_output_height() const { + return (input_size.h + 2 * padding.h - dilation.h * (kernel.h - 1) - 1) / + stride.h + + 1; + } + + int64_t get_output_width() const { + return (input_size.w + 2 * padding.w - dilation.w * (kernel.w - 1) - 1) / + stride.w + + 1; + } +}; + +// Utility function to create a test case from a Conv2dConfig +TestCase create_test_case_from_config( + const Conv2dConfig& config, + utils::StorageType storage_type, + vkapi::ScalarType input_dtype) { + TestCase test_case; + + // Create a descriptive name for the test case + std::string storage_str = + (storage_type == 
utils::kTexture3D) ? "Texture3D" : "Buffer"; + std::string dtype_str = (input_dtype == vkapi::kFloat) ? "Float" : "Half"; + + std::string test_name = + config.test_case_name + "_" + storage_str + "_" + dtype_str; + test_case.set_name(test_name); + + // Set the operator name for the test case + std::string operator_name = "et_vk." + config.op_name + ".default"; + test_case.set_operator_name(operator_name); + + // Calculate output dimensions + int64_t H_out = config.get_output_height(); + int64_t W_out = config.get_output_width(); + + // Input tensor (float/half) - [1, C_in, H_in, W_in] (batch size always 1) + std::vector input_size = { + 1, config.channels.in, config.input_size.h, config.input_size.w}; + + ValueSpec input_tensor( + input_size, + input_dtype, + storage_type, + utils::kChannelsPacked, + DataGenType::RANDOM); + + if (debugging()) { + print_valuespec_data(input_tensor, "input_tensor"); + } + + float input_scale_val = 0.07f; + ValueSpec input_scale(input_scale_val); + + int32_t input_zero_point_val = -3; + ValueSpec input_zero_point(input_zero_point_val); + + // Quantized weight tensor (int8) - [C_out, C_in_per_group * K_h * K_w] + // Memory layout: height, width, then channels - in_c is innermost (stride 1) + // in the second dimension + const int64_t in_channels_per_group = config.channels.in / config.groups; + const int64_t in_features = utils::align_up_4( + in_channels_per_group * config.kernel.h * config.kernel.w); + std::vector weight_size = {config.channels.out, in_features}; + ValueSpec quantized_weight( + weight_size, + vkapi::kChar, // int8 for quantized weights + storage_type, + utils::kWidthPacked, + DataGenType::RANDINT8); + quantized_weight.set_constant(true); + + if (debugging()) { + print_valuespec_data(quantized_weight, "weight_tensor"); + } + + // Weight quantization scales (float/half, per-channel) + ValueSpec weight_scales( + {config.channels.out}, // Per output channel + input_dtype, + storage_type, + utils::kWidthPacked, + DataGenType::RANDOM_SCALES); + weight_scales.set_constant(true); + + ValueSpec weight_sums( + {config.channels.out}, // Per output channel + vkapi::kInt, + storage_type, + utils::kWidthPacked, + DataGenType::ZEROS); + weight_sums.set_constant(true); + + // Compute weight_sums data based on quantized weights + compute_weight_sums( + weight_sums, quantized_weight, config.channels.out, in_features); + + // Bias (optional, float/half) - [C_out] + ValueSpec bias( + {config.channels.out}, // Per output channel + input_dtype, + storage_type, + utils::kWidthPacked, + DataGenType::RANDOM); + bias.set_constant(true); + + // Stride and padding parameters + ValueSpec stride({config.stride.h, config.stride.w}); + ValueSpec padding({config.padding.h, config.padding.w}); + + // Dilation and groups parameters + ValueSpec dilation({config.dilation.h, config.dilation.w}); + ValueSpec groups(config.groups); + + // Kernel size parameters + ValueSpec kernel_size({config.kernel.h, config.kernel.w}); + + // Output tensor (float/half) - [1, C_out, H_out, W_out] (batch size always 1) + ValueSpec output( + {1, config.channels.out, H_out, W_out}, + input_dtype, + storage_type, + utils::kChannelsPacked, + DataGenType::ZEROS); + + // Add all specs to test case + if (config.op_name.find("q8ta") != std::string::npos) { + test_case.add_input_spec(input_tensor); + test_case.add_input_spec(input_scale); + test_case.add_input_spec(input_zero_point); + test_case.add_input_spec(quantized_weight); + test_case.add_input_spec(weight_sums); + 
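The weight_sums input above is what lets the activation-quantized kernel apply the input zero-point correction per output channel; the q8ta reference implementation further down accumulates the same quantity on the fly. The sketch below shows what compute_weight_sums has to produce, inferred from that reference rather than taken from utils.h, so treat the helper name and signature as assumptions.

#include <cstdint>
#include <vector>

// One int32 per output channel: the sum of that channel's int8 weights.
// 'in_features' is the aligned row width, align_up_4(C_in_per_group * K_h * K_w),
// matching the packed weight layout used above.
std::vector<int32_t> reference_weight_sums(
    const std::vector<int8_t>& weights, // [C_out * in_features], row-major
    int64_t C_out,
    int64_t in_features) {
  std::vector<int32_t> sums(C_out, 0);
  for (int64_t oc = 0; oc < C_out; ++oc) {
    for (int64_t k = 0; k < in_features; ++k) {
      sums[oc] += static_cast<int32_t>(weights[oc * in_features + k]);
    }
  }
  return sums;
}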
test_case.add_input_spec(weight_scales); + test_case.add_input_spec(bias); + test_case.add_input_spec(kernel_size); + test_case.add_input_spec(stride); + test_case.add_input_spec(padding); + test_case.add_input_spec(dilation); + test_case.add_input_spec(groups); + } else { + test_case.add_input_spec(input_tensor); + test_case.add_input_spec(quantized_weight); + test_case.add_input_spec(weight_scales); + test_case.add_input_spec(bias); + test_case.add_input_spec(kernel_size); + test_case.add_input_spec(stride); + test_case.add_input_spec(padding); + test_case.add_input_spec(dilation); + test_case.add_input_spec(groups); + } + + test_case.add_output_spec(output); + + return test_case; +} + +// Generate easy test cases for quantized conv2d operation (for debugging) +std::vector generate_quantized_conv2d_easy_cases() { + std::vector test_cases; + + // Single simple configuration for debugging + Conv2dConfig config = { + OutInChannels(8, 3), // channels (out, in) + InputSize2D(8, 8), // input_size (h, w) + KernelSize(3, 3), // kernel + Stride(1, 1), // stride + Padding(0, 0), // padding + Dilation(1, 1), // dilation + 1, // groups + }; + + // Test with both storage types and data types for completeness + std::vector storage_types = {utils::kTexture3D}; + std::vector float_types = {vkapi::kFloat}; + + // Generate test cases for each combination + for (const auto& storage_type : storage_types) { + for (const auto& input_dtype : float_types) { + test_cases.push_back( + create_test_case_from_config(config, storage_type, input_dtype)); + } + } + + return test_cases; +} + +// Generate test cases for quantized conv2d operation +std::vector generate_quantized_conv2d_test_cases() { + std::vector test_cases; + + std::vector configs = { + {OutInChannels(32, 3), + InputSize2D(64, 64), + KernelSize(3, 3), + Stride(2, 2), + Padding(1, 1), + Dilation(1, 1), + 1}, + {OutInChannels(32, 16), + InputSize2D(32, 32), + KernelSize(3, 3), + Stride(1, 1), + Padding(1, 1), + Dilation(1, 1), + 1}, + {OutInChannels(64, 32), + InputSize2D(16, 16), + KernelSize(3, 3), + Stride(1, 1), + Padding(1, 1), + Dilation(1, 1), + 1}, + // One output channel case + {OutInChannels(1, 32), + InputSize2D(55, 55), + KernelSize(3, 3), + Stride(1, 1), + Padding(1, 1), + Dilation(1, 1), + 1}, + + // Stride 2 convolutions + {OutInChannels(32, 3), + InputSize2D(64, 64), + KernelSize(3, 3), + Stride(2, 2), + Padding(1, 1), + Dilation(1, 1), + 1}, + {OutInChannels(64, 32), + InputSize2D(32, 32), + KernelSize(3, 3), + Stride(2, 2), + Padding(1, 1), + Dilation(1, 1), + 1}, + // Different kernel sizes + {OutInChannels(32, 16), + InputSize2D(28, 28), + KernelSize(5, 5), + Stride(1, 1), + Padding(2, 2), + Dilation(1, 1), + 1}, + {OutInChannels(64, 32), + InputSize2D(14, 14), + KernelSize(7, 7), + Stride(1, 1), + Padding(3, 3), + Dilation(1, 1), + 1}, + + // Dilated convolutions + {OutInChannels(32, 16), + InputSize2D(32, 32), + KernelSize(3, 3), + Stride(1, 1), + Padding(2, 2), + Dilation(2, 2), + 1}, + {OutInChannels(64, 32), + InputSize2D(16, 16), + KernelSize(3, 3), + Stride(1, 1), + Padding(3, 3), + Dilation(3, 3), + 1}, + + // Grouped convolutions + {OutInChannels(32, 32), + InputSize2D(32, 32), + KernelSize(3, 3), + Stride(1, 1), + Padding(1, 1), + Dilation(1, 1), + 4}, + {OutInChannels(64, 64), + InputSize2D(16, 16), + KernelSize(3, 3), + Stride(1, 1), + Padding(1, 1), + Dilation(1, 1), + 8}, + // Performance test cases + {OutInChannels(256, 128), + InputSize2D(128, 128), + KernelSize(1, 1), + Stride(1, 1), + Padding(1, 1), + Dilation(1, 1), 
+ 8}, + {OutInChannels(128, 64), + InputSize2D(128, 128), + KernelSize(3, 3), + Stride(1, 1), + Padding(1, 1), + Dilation(1, 1), + 1}}; + + // Test with different storage types and data types + std::vector storage_types = {utils::kTexture3D}; + + // Generate test cases for each combination + for (auto& config : configs) { + for (const auto& storage_type : storage_types) { + // Generate test case name programmatically + bool is_performance = config.channels.out > kRefDimSizeLimit || + config.channels.in > kRefDimSizeLimit || + config.input_size.h > kRefDimSizeLimit || + config.input_size.w > kRefDimSizeLimit; + std::string prefix = is_performance ? "performance_" : "correctness_"; + std::string suffix = std::to_string(config.channels.out) + "/" + + std::to_string(config.channels.in) + "_" + + std::to_string(config.input_size.h) + "/" + + std::to_string(config.input_size.w) + "_" + + std::to_string(config.kernel.h) + "/" + + std::to_string(config.kernel.w); + + config.test_case_name = prefix + suffix; + // The default operator tested is activation + weight quantized conv2d; + // however, only test this if the int8 dot product extension is supported + if (vkcompute::api::context() + ->adapter_ptr() + ->supports_int8_dot_product()) { + test_cases.push_back( + create_test_case_from_config(config, storage_type, vkapi::kFloat)); + } + + Conv2dConfig wo_quant_config = config; + wo_quant_config.op_name = "conv2d_q8csw"; + test_cases.push_back(create_test_case_from_config( + wo_quant_config, storage_type, vkapi::kFloat)); + } + } + + return test_cases; +} + +// Reference implementation for weight only quantized conv2d (fp accumulation) +void conv2d_q8csw_reference_impl(TestCase& test_case) { + int32_t idx = 0; + const ValueSpec& input_spec = test_case.inputs()[idx++]; + const ValueSpec& weight_spec = test_case.inputs()[idx++]; + const ValueSpec& weight_scales_spec = test_case.inputs()[idx++]; + const ValueSpec& bias_spec = test_case.inputs()[idx++]; + const ValueSpec& kernel_size_spec = test_case.inputs()[idx++]; + const ValueSpec& stride_spec = test_case.inputs()[idx++]; + const ValueSpec& padding_spec = test_case.inputs()[idx++]; + const ValueSpec& dilation_spec = test_case.inputs()[idx++]; + const ValueSpec& groups_spec = test_case.inputs()[idx++]; + + // Extract output specification (mutable reference) + ValueSpec& output_spec = test_case.outputs()[0]; + + // Get tensor dimensions + auto input_sizes = input_spec.get_tensor_sizes(); // [N, C_in, H_in, W_in] + auto weight_sizes = + weight_spec.get_tensor_sizes(); // [C_out, C_in_per_group * K_h * K_w] + auto output_sizes = + output_spec.get_tensor_sizes(); // [N, C_out, H_out, W_out] + + int64_t N = input_sizes[0]; + int64_t C_in = input_sizes[1]; + int64_t H_in = input_sizes[2]; + int64_t W_in = input_sizes[3]; + int64_t C_out = output_sizes[1]; + int64_t H_out = output_sizes[2]; + int64_t W_out = output_sizes[3]; + + // Get kernel dimensions from kernel_size ValueSpec + auto kernel_size_data = kernel_size_spec.get_int32_data(); + int64_t K_h = kernel_size_data[0]; + int64_t K_w = kernel_size_data[1]; + + // Get stride, padding, dilation, and groups + auto stride_data = stride_spec.get_int32_data(); + auto padding_data = padding_spec.get_int32_data(); + auto dilation_data = dilation_spec.get_int32_data(); + int64_t stride_h = stride_data[0]; + int64_t stride_w = stride_data[1]; + int64_t pad_h = padding_data[0]; + int64_t pad_w = padding_data[1]; + int64_t dilation_h = dilation_data[0]; + int64_t dilation_w = dilation_data[1]; + int64_t groups = 
groups_spec.get_int_value(); + + // Skip for large tensors since computation time will be extremely slow + if (N > kRefDimSizeLimit || C_in > kRefDimSizeLimit || + H_in > kRefDimSizeLimit || W_in > kRefDimSizeLimit || + C_out > kRefDimSizeLimit) { + throw std::invalid_argument( + "One or more dimensions exceed the allowed limit for reference implementation."); + } + + if (input_spec.dtype != vkapi::kFloat) { + throw std::invalid_argument("Unsupported dtype"); + } + + // Get raw data pointers + auto& input_data = input_spec.get_float_data(); + auto& weight_data = weight_spec.get_int8_data(); + auto& weight_scales_data = weight_scales_spec.get_float_data(); + auto& bias_data = bias_spec.get_float_data(); + + // Calculate channels per group for grouped convolution + int64_t C_in_per_group = C_in / groups; + int64_t C_out_per_group = C_out / groups; + + // Calculate number of output elements + int64_t num_output_elements = N * C_out * H_out * W_out; + + auto& ref_data = output_spec.get_ref_float_data(); + ref_data.resize(num_output_elements); + + const int in_features = utils::align_up_4(C_in_per_group * K_h * K_w); + + // Perform weight-only quantized conv2d operation (fp accumulation) + for (int64_t n = 0; n < N; ++n) { + for (int64_t out_c = 0; out_c < C_out; ++out_c) { + for (int64_t out_h = 0; out_h < H_out; ++out_h) { + for (int64_t out_w = 0; out_w < W_out; ++out_w) { + float sum = 0.0f; + + // Determine which group this output channel belongs to + int64_t group_idx = out_c / C_out_per_group; + int64_t in_c_start = group_idx * C_in_per_group; + int64_t in_c_end = (group_idx + 1) * C_in_per_group; + + // Convolution operation with dilation support and grouped convolution + for (int64_t in_c = in_c_start; in_c < in_c_end; ++in_c) { + for (int64_t kh = 0; kh < K_h; ++kh) { + for (int64_t kw = 0; kw < K_w; ++kw) { + // Calculate input position with dilation + int64_t in_h = out_h * stride_h - pad_h + kh * dilation_h; + int64_t in_w = out_w * stride_w - pad_w + kw * dilation_w; + + // Check bounds (zero padding) + if (in_h >= 0 && in_h < H_in && in_w >= 0 && in_w < W_in) { + // Get input value (keep as float) + int64_t input_idx = n * (C_in * H_in * W_in) + + in_c * (H_in * W_in) + in_h * W_in + in_w; + float input_val = input_data[input_idx]; + + // Get weight value and dequantize + // Weight layout: [C_out, C_in_per_group * K_h * K_w] + int64_t weight_idx = out_c * in_features + + (kh * (K_w * C_in_per_group) + kw * C_in_per_group + + (in_c % C_in_per_group)); + float dequant_weight = + (static_cast(weight_data[weight_idx])) * + weight_scales_data[out_c]; + + sum += input_val * dequant_weight; + } + } + } + } + + // Add bias and store result + sum += bias_data[out_c]; + int64_t output_idx = n * (C_out * H_out * W_out) + + out_c * (H_out * W_out) + out_h * W_out + out_w; + ref_data[output_idx] = sum; + } + } + } + } +} + +// Reference implementation for activation and weight quantized conv2d (int +// accumulation) +void conv2d_q8ta_q8csw_reference_impl(TestCase& test_case) { + // Extract input specifications + int32_t idx = 0; + const ValueSpec& input_spec = test_case.inputs()[idx++]; + const ValueSpec& input_scale_spec = test_case.inputs()[idx++]; + const ValueSpec& input_zeros_spec = test_case.inputs()[idx++]; + const ValueSpec& weight_spec = test_case.inputs()[idx++]; + const ValueSpec& weight_sums_spec = test_case.inputs()[idx++]; + (void)weight_sums_spec; + const ValueSpec& weight_scales_spec = test_case.inputs()[idx++]; + const ValueSpec& bias_spec = test_case.inputs()[idx++]; + const 
ValueSpec& kernel_size_spec = test_case.inputs()[idx++]; + const ValueSpec& stride_spec = test_case.inputs()[idx++]; + const ValueSpec& padding_spec = test_case.inputs()[idx++]; + const ValueSpec& dilation_spec = test_case.inputs()[idx++]; + const ValueSpec& groups_spec = test_case.inputs()[idx++]; + + // Extract output specification (mutable reference) + ValueSpec& output_spec = test_case.outputs()[0]; + + // Get tensor dimensions + auto input_sizes = input_spec.get_tensor_sizes(); // [N, C_in, H_in, W_in] + auto weight_sizes = + weight_spec.get_tensor_sizes(); // [C_out, C_in_per_group * K_h * K_w] + auto output_sizes = + output_spec.get_tensor_sizes(); // [N, C_out, H_out, W_out] + + int64_t N = input_sizes[0]; + int64_t C_in = input_sizes[1]; + int64_t H_in = input_sizes[2]; + int64_t W_in = input_sizes[3]; + int64_t C_out = output_sizes[1]; + int64_t H_out = output_sizes[2]; + int64_t W_out = output_sizes[3]; + + // Get kernel dimensions from kernel_size ValueSpec + auto kernel_size_data = kernel_size_spec.get_int32_data(); + int64_t K_h = kernel_size_data[0]; + int64_t K_w = kernel_size_data[1]; + + // Get stride, padding, dilation, and groups + auto stride_data = stride_spec.get_int32_data(); + auto padding_data = padding_spec.get_int32_data(); + auto dilation_data = dilation_spec.get_int32_data(); + int64_t stride_h = stride_data[0]; + int64_t stride_w = stride_data[1]; + int64_t pad_h = padding_data[0]; + int64_t pad_w = padding_data[1]; + int64_t dilation_h = dilation_data[0]; + int64_t dilation_w = dilation_data[1]; + int64_t groups = groups_spec.get_int_value(); + + // Skip for large tensors since computation time will be extremely slow + if (N > kRefDimSizeLimit || C_in > kRefDimSizeLimit || + H_in > kRefDimSizeLimit || W_in > kRefDimSizeLimit || + C_out > kRefDimSizeLimit) { + throw std::invalid_argument( + "One or more dimensions exceed the allowed limit for reference implementation."); + } + + if (input_spec.dtype != vkapi::kFloat) { + throw std::invalid_argument("Unsupported dtype"); + } + + // Get raw data pointers + auto& input_data = input_spec.get_float_data(); + const float input_scale = input_scale_spec.get_float_value(); + const int32_t input_zero_point = input_zeros_spec.get_int_value(); + + auto& weight_data = weight_spec.get_int8_data(); + auto& weight_scales_data = weight_scales_spec.get_float_data(); + auto& bias_data = bias_spec.get_float_data(); + + // Calculate channels per group for grouped convolution + int64_t C_in_per_group = C_in / groups; + int64_t C_out_per_group = C_out / groups; + + // Calculate number of output elements + int64_t num_output_elements = N * C_out * H_out * W_out; + + auto& ref_data = output_spec.get_ref_float_data(); + ref_data.resize(num_output_elements); + + const int in_features = utils::align_up_4(C_in_per_group * K_h * K_w); + + // Perform activation and weight quantized conv2d operation (int accumulation) + for (int64_t n = 0; n < N; ++n) { + for (int64_t out_c = 0; out_c < C_out; ++out_c) { + for (int64_t out_h = 0; out_h < H_out; ++out_h) { + for (int64_t out_w = 0; out_w < W_out; ++out_w) { + int32_t int_sum = 0; + int32_t weight_sum = 0; // Track weight sum on the fly + + // Determine which group this output channel belongs to + int64_t group_idx = out_c / C_out_per_group; + int64_t in_c_start = group_idx * C_in_per_group; + int64_t in_c_end = (group_idx + 1) * C_in_per_group; + + // Convolution operation with integer accumulation + for (int64_t in_c = in_c_start; in_c < in_c_end; ++in_c) { + for (int64_t kh = 0; kh < K_h; 
++kh) { + for (int64_t kw = 0; kw < K_w; ++kw) { + // Calculate input position with dilation + int64_t in_h = out_h * stride_h - pad_h + kh * dilation_h; + int64_t in_w = out_w * stride_w - pad_w + kw * dilation_w; + + // Check bounds (zero padding) + if (in_h >= 0 && in_h < H_in && in_w >= 0 && in_w < W_in) { + // Get input value and quantize to int8 + int64_t input_idx = n * (C_in * H_in * W_in) + + in_c * (H_in * W_in) + in_h * W_in + in_w; + + float quant_input_f = + std::round(input_data[input_idx] / input_scale) + + input_zero_point; + quant_input_f = + std::min(std::max(quant_input_f, -128.0f), 127.0f); + int8_t quantized_input = static_cast(quant_input_f); + + // Get quantized weight (already int8) + // Weight layout: [C_out, C_in_per_group * K_h * K_w] + int64_t weight_idx = out_c * in_features + + (kh * (K_w * C_in_per_group) + kw * C_in_per_group + + (in_c % C_in_per_group)); + int8_t quantized_weight = weight_data[weight_idx]; + + // Integer multiplication and accumulation + int_sum += static_cast(quantized_input) * + static_cast(quantized_weight); + + // Track weight sum for this output channel on the fly + weight_sum += static_cast(quantized_weight); + } else { + // For zero padding, we still need to account for the weight + // in weight_sum when input is effectively 0 (but quantized 0 + // is input_zero_point) + int64_t weight_idx = out_c * in_features + + (kh * (K_w * C_in_per_group) + kw * C_in_per_group + + (in_c % C_in_per_group)); + int8_t quantized_weight = weight_data[weight_idx]; + + // Add contribution from zero-padded input (quantized zero = + // input_zero_point) + int_sum += static_cast(input_zero_point) * + static_cast(quantized_weight); + + // Track weight sum for this output channel on the fly + weight_sum += static_cast(quantized_weight); + } + } + } + } + + // Convert accumulated integer result to float and apply scales + // Final result = (int_sum - zero_point_correction) * input_scale * + // weight_scale + bias zero_point_correction = input_zero_point * + // sum_of_weights_for_this_output_channel + int32_t zero_point_correction = input_zero_point * weight_sum; + int32_t accum_adjusted = int_sum - zero_point_correction; + float float_result = + accum_adjusted * input_scale * weight_scales_data[out_c]; + + // Add bias and store result + float_result += bias_data[out_c]; + int64_t output_idx = n * (C_out * H_out * W_out) + + out_c * (H_out * W_out) + out_h * W_out + out_w; + ref_data[output_idx] = float_result; + } + } + } + } +} + +void reference_impl(TestCase& test_case) { + if (test_case.operator_name().find("q8ta") != std::string::npos) { + conv2d_q8ta_q8csw_reference_impl(test_case); + } else { + conv2d_q8csw_reference_impl(test_case); + } +} + +// Custom FLOP calculator for quantized conv2d operation +int64_t quantized_conv2d_flop_calculator(const TestCase& test_case) { + int kernel_idx = 4; + if (test_case.operator_name().find("q8ta") != std::string::npos) { + kernel_idx = 7; + } + // Get input and weight dimensions + const auto& input_sizes = test_case.inputs()[0].get_tensor_sizes(); + const auto& output_sizes = test_case.outputs()[0].get_tensor_sizes(); + + const auto& kernel_sizes = test_case.inputs()[kernel_idx].get_int32_data(); + + int64_t N = input_sizes[0]; + int64_t C_in = input_sizes[1]; + int64_t C_out = output_sizes[1]; + int64_t K_h = kernel_sizes[0]; + int64_t K_w = kernel_sizes[1]; + int64_t H_out = output_sizes[2]; + int64_t W_out = output_sizes[3]; + + // Calculate FLOPs for quantized conv2d operation + // Each output element requires: 
+ // - C_in * K_h * K_w multiply-accumulate operations + // - Additional operations for quantization/dequantization + int64_t output_elements = N * C_out * H_out * W_out; + int64_t ops_per_output = C_in * K_h * K_w; + + int64_t flop = output_elements * (ops_per_output); + + return flop; +} + +int main(int argc, char* argv[]) { + set_debugging(false); + set_print_output(false); + set_print_latencies(false); + set_use_gpu_timestamps(true); + + print_performance_header(); + std::cout << "Quantized Conv2d Operation Prototyping Framework" << std::endl; + print_separator(); + + ReferenceComputeFunc ref_fn = reference_impl; + + // Execute test cases using the new framework with custom FLOP calculator + auto results = execute_test_cases( + generate_quantized_conv2d_test_cases, + quantized_conv2d_flop_calculator, + "QuantizedConv2d", + 0, + 10, + ref_fn); + + return 0; +} diff --git a/backends/vulkan/test/custom_ops/targets.bzl b/backends/vulkan/test/custom_ops/targets.bzl index 68bdc9e6fbd..4297565da80 100644 --- a/backends/vulkan/test/custom_ops/targets.bzl +++ b/backends/vulkan/test/custom_ops/targets.bzl @@ -94,3 +94,4 @@ def define_common_targets(is_fbcode = False): define_custom_op_test_binary("add") define_custom_op_test_binary("q8csw_linear") + define_custom_op_test_binary("q8csw_conv2d")
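One identity ties the two reference implementations together: with q_x = clamp(round(x / s_x) + zp_x) and w ≈ q_w * s_w, each output accumulation satisfies

  sum_k x_k * w_k  ≈  ( sum_k q_x[k] * q_w[k]  -  zp_x * sum_k q_w[k] ) * s_x * s_w

which is exactly the int_sum / zero_point_correction arithmetic in conv2d_q8ta_q8csw_reference_impl, and the reason the per-channel weight sums can be precomputed at prepack time. The self-contained check below uses the test's input scale (0.07) and zero point (-3); the weight scale and the tiny vectors are arbitrary illustrative values, not taken from the patch.

#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const float s_x = 0.07f;   // input scale, as in the test case above
  const int32_t zp_x = -3;   // input zero point, as in the test case above
  const float s_w = 0.05f;   // assumed per-channel weight scale

  std::vector<float> x = {0.35f, -0.14f, 0.70f};  // activations
  std::vector<int8_t> q_w = {12, -7, 3};          // already-quantized weights

  int32_t int_sum = 0;
  int32_t weight_sum = 0;  // plays the role of one precomputed weight_sums entry
  float fp_sum = 0.0f;     // weight-only path: float activation * dequantized weight
  for (size_t k = 0; k < x.size(); ++k) {
    // Quantize the activation (clamping to [-128, 127] omitted; these samples stay in range).
    const int32_t q_x = static_cast<int32_t>(std::round(x[k] / s_x)) + zp_x;
    int_sum += q_x * static_cast<int32_t>(q_w[k]);
    weight_sum += static_cast<int32_t>(q_w[k]);
    fp_sum += x[k] * (static_cast<float>(q_w[k]) * s_w);
  }
  const float int_path =
      static_cast<float>(int_sum - zp_x * weight_sum) * s_x * s_w;

  // The two paths agree up to activation-quantization rounding error.
  std::printf("fp path: %f  int path: %f\n", fp_sum, int_path);
  return 0;
}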