diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml
index 4215db1e2ca..c2f346d4c84 100644
--- a/.github/workflows/pull.yml
+++ b/.github/workflows/pull.yml
@@ -1009,6 +1009,7 @@ jobs:
         ./cmake-out/backends/vulkan/test/custom_ops/q8csw_conv2d
         ./cmake-out/backends/vulkan/test/custom_ops/q4gsw_linear
         ./cmake-out/backends/vulkan/test/custom_ops/choose_qparams_per_row
+        ./cmake-out/backends/vulkan/test/custom_ops/qdq8ta_conv2d_activations

         # "Classic" Operator tests
         PYTHON_EXECUTABLE=python bash backends/vulkan/test/scripts/test_op.sh --build
diff --git a/backends/vulkan/runtime/graph/ops/glsl/common.glslh b/backends/vulkan/runtime/graph/ops/glsl/common.glslh
index 732b7006c2c..00a053612f5 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/common.glslh
+++ b/backends/vulkan/runtime/graph/ops/glsl/common.glslh
@@ -33,6 +33,30 @@ struct TensorIndex4D {
   ivec4 data;
 };

+int sign_extend_8bit(const int val) {
+  if ((val & 0x80) != 0) {
+    return val | (~0xFF);
+  }
+  return val;
+}
+
+int extract_8bit_from_packed_int_le(const int packed, const int i) {
+  // account for little endian
+  int byte = sign_extend_8bit(packed >> (8 * i) & 0xFF);
+  return byte;
+}
+
+int pack_4xqint_into_int32(
+    const int val0,
+    const int val1,
+    const int val2,
+    const int val3) {
+  int packed = (val0 & 0xFF) | ((val1 & 0xFF) << 8) | ((val2 & 0xFF) << 16) |
+      ((val3 & 0xFF) << 24);
+
+  return packed;
+}
+
 #ifdef DEBUG_MODE

 #extension GL_EXT_debug_printf : require
diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_common.glslh b/backends/vulkan/runtime/graph/ops/glsl/conv2d_common.glslh
index 41825cba867..929f3da299e 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_common.glslh
+++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_common.glslh
@@ -27,6 +27,48 @@ struct Conv2DParams {
   int K4;
 };

+struct Conv2dTensorIndex {
+  ivec3 data;
+  int texel_i;
+};
+
+struct Conv2dBlockIndex {
+  ivec3 data;
+};
+
+Conv2dTensorIndex block_idx_to_tensor_idx(const Conv2dBlockIndex block_idx) {
+  Conv2dTensorIndex tensor_idx;
+  tensor_idx.data.x = mul_4(block_idx.data.x);
+  tensor_idx.data.y = block_idx.data.y;
+  tensor_idx.data.z = block_idx.data.z;
+  tensor_idx.texel_i = 0;
+  return tensor_idx;
+}
+
+struct Conv2dBlockExtents {
+  ivec3 data;
+  int data_xz;
+};
+
+Conv2dBlockExtents make_block_extents(const ivec4 tensor_sizes) {
+  Conv2dBlockExtents block_sizes;
+  block_sizes.data.x = div_up_4(tensor_sizes.x);
+  block_sizes.data.y = tensor_sizes.y;
+  block_sizes.data.z = div_up_4(tensor_sizes.z);
+
+  block_sizes.data_xz = block_sizes.data.x * block_sizes.data.z;
+
+  return block_sizes;
+}
+
+bool block_idx_out_of_bounds(
+    const Conv2dBlockIndex block_idx,
+    const Conv2dBlockExtents block_extents) {
+  return block_idx.data.x >= block_extents.data.x ||
+      block_idx.data.y >= block_extents.data.y ||
+      block_idx.data.z >= block_extents.data.z;
+}
+
 #ifdef DEBUG_MODE

 void printConv2DParams(const Conv2DParams params) {
diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_fp_input_tile_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/conv2d_fp_input_tile_load.glslh
new file mode 100644
index 00000000000..be8a76421a5
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_fp_input_tile_load.glslh
@@ -0,0 +1,34 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#ifndef CONV2D_FP_INPUT_TILE_LOAD
+#define CONV2D_FP_INPUT_TILE_LOAD
+
+#extension GL_EXT_control_flow_attributes : require
+
+#include "linear_fp_input_tile.glslh"
+
+VEC4_T load_fp_input_texel(const Conv2dTensorIndex tidx) {
+  return texelFetch(t_fp_input, tidx.data, 0);
+}
+
+void load_fp_input_tile(
+    out FPInputTile tile,
+    const Conv2dBlockIndex block_idx) {
+#if TILE_M == 4 && TILE_K4 == 1
+  Conv2dTensorIndex load_tidx = block_idx_to_tensor_idx(block_idx);
+  [[unroll]] for (int w = 0; w < TILE_M; w++) {
+    tile.data[w][0] = load_fp_input_texel(load_tidx);
+    load_tidx.data.x++;
+  }
+#else
+  not_implemented;
+#endif
+}
+
+#endif // CONV2D_FP_INPUT_TILE_LOAD
diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_common.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_common.glslh
index da326b26e93..c95abdcb230 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/linear_common.glslh
+++ b/backends/vulkan/runtime/graph/ops/glsl/linear_common.glslh
@@ -16,19 +16,6 @@

 #include "common.glslh"

-int sign_extend_8bit(const int val) {
-  if ((val & 0x80) != 0) {
-    return val | (~0xFF);
-  }
-  return val;
-}
-
-int extract_8bit_from_packed_int_le(const int packed, const int i) {
-  // account for little endian
-  int byte = sign_extend_8bit(packed >> (8 * i) & 0xFF);
-  return byte;
-}
-
 // Extract a 4-bit value from a packed int (little endian)
 // It is assumed that the 4-bit value is in the range [0, 15]
 int extract_4bit_from_packed_int_le(const int packed, const int col) {
diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_q8ta_conv2d_input.glsl b/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_q8ta_conv2d_input.glsl
new file mode 100644
index 00000000000..d485523709b
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_q8ta_conv2d_input.glsl
@@ -0,0 +1,77 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#version 450 core
+
+#define PRECISION ${PRECISION}
+#define VEC4_T ${texel_load_type(DTYPE, INPUT_STORAGE)}
+#define T ${texel_load_component_type(DTYPE, INPUT_STORAGE)}
+
+// corresponds to the input width dim
+#define TILE_M4 1
+// corresponds to the input channels dim
+#define TILE_K4 1
+
+#define TILE_M 4
+
+$if OUTPUT_STORAGE == "buffer":
+  #define OUTPUT_BUFFER
+$if INPUT_STORAGE == "buffer":
+  #define INPUT_BUFFER
+
+${define_required_extensions(DTYPE)}
+
+layout(std430) buffer;
+
+#include "conv2d_common.glslh"
+
+${layout_declare_tensor(B, "w", "t_packed_int8_input", "int", OUTPUT_STORAGE, is_scalar_array=False)}
+${layout_declare_tensor(B, "r", "t_fp_input", DTYPE, INPUT_STORAGE, is_scalar_array=False)}
+
+${layout_declare_ubo(B, "ivec4", "input_sizes")}
+
+layout(push_constant) uniform restrict Block {
+  float inv_scale;
+  int zp;
+};
+
+layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
+
+#include "conv2d_fp_input_tile_load.glslh"
+#include "linear_int8_input_block.glslh"
+
+void store_packed_int8_block(
+    const Conv2dBlockIndex block_idx,
+    const Conv2dBlockExtents block_extents,
+    const Int8InputBlock packed_int8_block) {
+#ifdef OUTPUT_BUFFER
+  const int buffer_idx = block_idx.data.y * block_extents.data_xz +
+      block_idx.data.x * block_extents.data.z + block_idx.data.z;
+  t_packed_int8_input[buffer_idx] = packed_int8_block.data;
+#else
+  imageStore(t_packed_int8_input, block_idx.data, packed_int8_block.data);
+#endif
+}
+
+void main() {
+  Conv2dBlockIndex block_idx;
+  block_idx.data = ivec3(gl_GlobalInvocationID);
+
+  Conv2dBlockExtents block_extents = make_block_extents(input_sizes);
+  if (block_idx_out_of_bounds(block_idx, block_extents)) {
+    return;
+  }
+
+  FPInputTile fp_tile;
+  load_fp_input_tile(fp_tile, block_idx);
+
+  Int8InputBlock int8_block;
+  quantize_and_pack(int8_block, fp_tile, inv_scale, zp);
+
+  store_packed_int8_block(block_idx, block_extents, int8_block);
+}
diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_q8ta_conv2d_input.yaml b/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_q8ta_conv2d_input.yaml
new file mode 100644
index 00000000000..712d3156e2e
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_q8ta_conv2d_input.yaml
@@ -0,0 +1,21 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+quantize_and_pack_q8ta_conv2d_input:
+  parameter_names_with_default_values:
+    DTYPE: float
+    OUTPUT_STORAGE: texture3d
+    INPUT_STORAGE: texture3d
+  generate_variant_forall:
+    combination:
+      parameter_names: [OUTPUT_STORAGE, INPUT_STORAGE]
+      combos:
+        - parameter_values: [texture3d, texture3d]
+        - parameter_values: [buffer, texture3d]
+    DTYPE:
+      - VALUE: float
+  shader_variants:
+    - NAME: quantize_and_pack_q8ta_conv2d_input
diff --git a/backends/vulkan/runtime/graph/ops/glsl/unpack_and_dequantize_q8ta_conv2d_output.glsl b/backends/vulkan/runtime/graph/ops/glsl/unpack_and_dequantize_q8ta_conv2d_output.glsl
new file mode 100644
index 00000000000..ed7dd25421a
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/glsl/unpack_and_dequantize_q8ta_conv2d_output.glsl
@@ -0,0 +1,118 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#version 450 core
+
+#define PRECISION ${PRECISION}
+#define VEC4_T ${texel_load_type(DTYPE, INPUT_STORAGE)}
+#define T ${texel_load_component_type(DTYPE, INPUT_STORAGE)}
+
+// corresponds to the output width dim
+#define TILE_M4 1
+// corresponds to the output channels dim
+#define TILE_K4 1
+
+#define TILE_M 4
+
+$if OUTPUT_STORAGE == "buffer":
+  #define OUTPUT_BUFFER
+$if INPUT_STORAGE == "buffer":
+  #define INPUT_BUFFER
+
+${define_required_extensions(DTYPE)}
+
+layout(std430) buffer;
+
+#define DEBUG_MODE
+#include "conv2d_common.glslh"
+
+${layout_declare_tensor(B, "w", "t_fp_output", DTYPE, OUTPUT_STORAGE, is_scalar_array=False)}
+${layout_declare_tensor(B, "r", "t_packed_int8_output", "int", INPUT_STORAGE, is_scalar_array=False)}
+
+${layout_declare_ubo(B, "ivec4", "output_sizes")}
+
+layout(push_constant) uniform restrict Block {
+  float scale;
+  int zp;
+};
+
+layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
+
+#include "linear_fp_input_tile.glslh"
+#include "linear_int8_input_tile.glslh"
+
+void load_packed_int8_tile(
+    out Int8InputTile int8_tile,
+    const Conv2dBlockIndex block_idx,
+    const Conv2dBlockExtents block_extents) {
+#ifdef INPUT_BUFFER
+  const int buffer_idx = block_idx.data.y * block_extents.data_xz +
+      block_idx.data.x * block_extents.data.z + block_idx.data.z;
+  int8_tile.data[0][0] = t_packed_int8_output[buffer_idx];
+#else
+  int8_tile.data[0][0] = texelFetch(t_packed_int8_output, block_idx.data, 0);
+#endif
+}
+
+VEC4_T
+dequantize_8bit(const ivec4 val, const float q_scale, const int q_zero_point) {
+  return VEC4_T(val - q_zero_point) * q_scale;
+}
+
+void unpack_and_dequantize(
+    out FPInputTile fp_tile,
+    const Int8InputTile int8_tile,
+    const float q_scale,
+    const int q_zero_point) {
+  [[unroll]] for (int w = 0; w < 4; ++w) {
+    int packed = int8_tile.data[0][0][w];
+    fp_tile.data[w][0] = dequantize_8bit(
+        ivec4(
+            extract_8bit_from_packed_int_le(packed, 0),
+            extract_8bit_from_packed_int_le(packed, 1),
+            extract_8bit_from_packed_int_le(packed, 2),
+            extract_8bit_from_packed_int_le(packed, 3)),
+        q_scale,
+        q_zero_point);
+  }
+}
+
+void store_fp_output_texel(
+    const Conv2dTensorIndex tidx,
+    const VEC4_T out_texel) {
+  imageStore(t_fp_output, tidx.data, out_texel);
+}
+
+void store_fp_tile(
+    const FPInputTile block,
+    const Conv2dBlockIndex block_idx) {
+  Conv2dTensorIndex store_tidx = block_idx_to_tensor_idx(block_idx);
+  [[unroll]] for (int w = 0; w < 4; w++) {
+    store_fp_output_texel(store_tidx, block.data[w][0]);
+    store_tidx.data.x++;
+  }
+}
+
+void main() {
+  Conv2dBlockIndex block_idx;
+  block_idx.data = ivec3(gl_GlobalInvocationID);
+
+  Conv2dBlockExtents block_extents = make_block_extents(output_sizes);
+  if (block_idx_out_of_bounds(block_idx, block_extents)) {
+    return;
+  }
+
+  Int8InputTile int8_tile;
+  load_packed_int8_tile(int8_tile, block_idx, block_extents);
+
+  FPInputTile fp_tile;
+  unpack_and_dequantize(
+      fp_tile, int8_tile, scale, zp);
+
+  store_fp_tile(fp_tile, block_idx);
+}
diff --git a/backends/vulkan/runtime/graph/ops/glsl/unpack_and_dequantize_q8ta_conv2d_output.yaml b/backends/vulkan/runtime/graph/ops/glsl/unpack_and_dequantize_q8ta_conv2d_output.yaml
new file mode 100644
index 00000000000..24b253da343
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/glsl/unpack_and_dequantize_q8ta_conv2d_output.yaml
@@ -0,0 +1,21 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+unpack_and_dequantize_q8ta_conv2d_output:
+  parameter_names_with_default_values:
+    DTYPE: float
+    OUTPUT_STORAGE: texture3d
+    INPUT_STORAGE: texture3d
+  generate_variant_forall:
+    combination:
+      parameter_names: [OUTPUT_STORAGE, INPUT_STORAGE]
+      combos:
+        - parameter_values: [texture3d, texture3d]
+        - parameter_values: [texture3d, buffer]
+    DTYPE:
+      - VALUE: float
+  shader_variants:
+    - NAME: unpack_and_dequantize_q8ta_conv2d_output
diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedConvolution.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedConvolution.cpp
index 9fc9fd52ad6..f6eee4ba12e 100644
--- a/backends/vulkan/runtime/graph/ops/impl/QuantizedConvolution.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedConvolution.cpp
@@ -156,6 +156,40 @@ std::vector<int64_t> calculate_output_im2col_sizes(
 // Shader dispatch utilities
 //

+utils::uvec3 pick_quantize_and_pack_conv2d_input_global_wg_size(
+    ComputeGraph* graph,
+    const vkapi::ShaderInfo& shader,
+    const std::vector<ArgGroup>& args,
+    const std::vector<ValueRef>& resize_args) {
+  const ValueRef fp_input = args.at(1).refs.at(0);
+
+  const uint32_t W = graph->size_at<uint32_t>(-1, fp_input);
+  const uint32_t H = graph->size_at<uint32_t>(-2, fp_input);
+  const uint32_t C = graph->size_at<uint32_t>(-3, fp_input);
+
+  const uint32_t W4 = utils::div_up_4(W);
+  const uint32_t C4 = utils::div_up_4(C);
+
+  return {W4, H, C4};
+}
+
+utils::uvec3 pick_unpack_and_dequantize_conv2d_output_global_wg_size(
+    ComputeGraph* graph,
+    const vkapi::ShaderInfo& shader,
+    const std::vector<ArgGroup>& args,
+    const std::vector<ValueRef>& resize_args) {
+  const ValueRef fp_output = args.at(0).refs.at(0);
+
+  const uint32_t W = graph->size_at<uint32_t>(-1, fp_output);
+  const uint32_t H = graph->size_at<uint32_t>(-2, fp_output);
+  const uint32_t C = graph->size_at<uint32_t>(-3, fp_output);
+
+  const uint32_t W4 = utils::div_up_4(W);
+  const uint32_t C4 = utils::div_up_4(C);
+
+  return {W4, H, C4};
+}
+
 utils::uvec3 im2col_global_wg_size(
     ComputeGraph* graph,
     const vkapi::ShaderInfo& shader,
@@ -251,6 +285,94 @@ void add_input_im2col_node(
       nullptr));
 }

+void add_quantize_and_pack_q8ta_conv2d_input_node(
+    ComputeGraph& graph,
+    const ValueRef fp_input,
+    const ValueRef input_scale,
+    const ValueRef input_zp,
+    const ValueRef packed_int8_input) {
+  float inv_scale = 1.0f / graph.extract_scalar<float>(input_scale);
+  int32_t zp = graph.extract_scalar<int32_t>(input_zp);
+
+  // Get shader for quantizing and packing the conv2d input
+  std::string kernel_name = "quantize_and_pack_q8ta_conv2d_input";
+  add_storage_type_suffix(
+      kernel_name, graph.storage_type_of(packed_int8_input));
+  add_storage_type_suffix(kernel_name, graph.storage_type_of(fp_input));
+  add_dtype_suffix(kernel_name, graph.dtype_of(fp_input));
+
+  vkapi::ShaderInfo shader = VK_KERNEL_FROM_STR(kernel_name);
+
+  vkapi::ParamsBindList param_buffers = {graph.sizes_ubo(fp_input)};
+
+  std::vector<PushConstantDataInfo> push_constants = {
+      PushConstantDataInfo(&inv_scale, sizeof(inv_scale)),
+      PushConstantDataInfo(&zp, sizeof(zp)),
+  };
+
+  graph.execute_nodes().emplace_back(new DynamicDispatchNode(
+      graph,
+      VK_KERNEL_FROM_STR(kernel_name),
+      pick_quantize_and_pack_conv2d_input_global_wg_size,
+      default_pick_local_wg_size,
+      // Inputs and Outputs
+      {{packed_int8_input, vkapi::kWrite}, {fp_input, vkapi::kRead}},
+      // Shader params buffers
+      param_buffers,
+      // Push Constants
+      push_constants,
+      // Specialization Constants
+      {},
+      // Resize args
+      {},
+      // Resizing Logic
+      nullptr));
+}
+
+void add_unpack_and_dequantize_q8ta_conv2d_output_node(
+    ComputeGraph& graph,
+    const ValueRef packed_int8_output,
+    const ValueRef output_scale,
+    const ValueRef output_zp,
+    const ValueRef fp_output) {
+  float scale = graph.extract_scalar<float>(output_scale);
+  int32_t zp = graph.extract_scalar<int32_t>(output_zp);
+
+  // Get shader for unpacking and dequantizing the conv2d output
+  std::string kernel_name = "unpack_and_dequantize_q8ta_conv2d_output";
+  add_storage_type_suffix(kernel_name, graph.storage_type_of(fp_output));
+  add_storage_type_suffix(
+      kernel_name, graph.storage_type_of(packed_int8_output));
+  add_dtype_suffix(kernel_name, graph.dtype_of(fp_output));
+
+  vkapi::ShaderInfo shader = VK_KERNEL_FROM_STR(kernel_name);
+
+  vkapi::ParamsBindList param_buffers = {graph.sizes_ubo(fp_output)};
+
+  std::vector<PushConstantDataInfo> push_constants = {
+      PushConstantDataInfo(&scale, sizeof(scale)),
+      PushConstantDataInfo(&zp, sizeof(zp)),
+  };
+
+  graph.execute_nodes().emplace_back(new DynamicDispatchNode(
+      graph,
+      VK_KERNEL_FROM_STR(kernel_name),
+      pick_unpack_and_dequantize_conv2d_output_global_wg_size,
+      default_pick_local_wg_size,
+      // Inputs and Outputs
+      {{fp_output, vkapi::kWrite}, {packed_int8_output, vkapi::kRead}},
+      // Shader params buffers
+      param_buffers,
+      // Push Constants
+      push_constants,
+      // Specialization Constants
+      {},
+      // Resize args
+      {},
+      // Resizing Logic
+      nullptr));
+}
+
 void add_quantize_and_pack_im2col_node(
     ComputeGraph& graph,
     const ValueRef input_image,
@@ -683,9 +805,37 @@ void conv2d_q8csw(ComputeGraph& graph, const std::vector<ValueRef>& args) {
       output_image);
 }

+//
+// Quantize and dequantize operators
+//
+
+void qdq8ta_conv2d_input(
+    ComputeGraph& graph,
+    const std::vector<ValueRef>& args) {
+  int32_t idx = 0;
+  const ValueRef fp_input = args.at(idx++);
+  const ValueRef scale = args.at(idx++);
+  const ValueRef zero_point = args.at(idx++);
+  const ValueRef fp_output = args.at(idx++);
+
+  TmpTensor packed_int8_input(
+      &graph,
+      graph.sizes_of(fp_input),
+      vkapi::kInt8x4,
+      utils::kBuffer,
+      utils::kPackedInt8_4W4C);
+
+  add_quantize_and_pack_q8ta_conv2d_input_node(
+      graph, fp_input, scale, zero_point, packed_int8_input);
+
+  add_unpack_and_dequantize_q8ta_conv2d_output_node(
+      graph, packed_int8_input, scale, zero_point, fp_output);
+}
+
 REGISTER_OPERATORS {
   VK_REGISTER_OP(et_vk.conv2d_q8ta_q8csw.default, conv2d_q8ta_q8csw);
   VK_REGISTER_OP(et_vk.conv2d_q8csw.default, conv2d_q8csw);
+  VK_REGISTER_OP(etvk.qdq8ta_conv2d_input.default, qdq8ta_conv2d_input);
 }

 } // namespace vkcompute
diff --git a/backends/vulkan/test/custom_ops/CMakeLists.txt b/backends/vulkan/test/custom_ops/CMakeLists.txt
index 97b632338db..fe36de3047e 100644
--- a/backends/vulkan/test/custom_ops/CMakeLists.txt
+++ b/backends/vulkan/test/custom_ops/CMakeLists.txt
@@ -95,4 +95,5 @@ if(TARGET vulkan_backend)
   add_operator_prototype(q8csw_conv2d)
   add_operator_prototype(q4gsw_linear)
   add_operator_prototype(choose_qparams_per_row)
+  add_operator_prototype(qdq8ta_conv2d_activations)
 endif()
diff --git a/backends/vulkan/test/custom_ops/qdq8ta_conv2d_activations.cpp b/backends/vulkan/test/custom_ops/qdq8ta_conv2d_activations.cpp
new file mode 100644
index 00000000000..5275e6c9335
--- /dev/null
+++ b/backends/vulkan/test/custom_ops/qdq8ta_conv2d_activations.cpp
@@ -0,0 +1,251 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
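+//
+// Prototyping harness for the etvk.qdq8ta_conv2d_input operator: each test
+// case quantizes a float conv2d activation tensor to the packed int8 layout,
+// dequantizes it back, and compares the result against the CPU reference
+// implementation defined below.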
+
+#include <algorithm>
+#include <cmath>
+#include <cstdint>
+#include <iostream>
+#include <string>
+#include <vector>
+#include "utils.h"
+
+#include <stdexcept>
+
+using namespace executorch::vulkan::prototyping;
+using namespace vkcompute;
+
+static constexpr int64_t kRefDimSizeLimit = 512;
+
+// QDQ8TA Conv2D configuration struct for 4D tensor quantize-dequantize testing
+struct QDQ8TAConv2DConfig {
+  int64_t batch_size; // N dimension
+  int64_t in_channels; // C dimension
+  int64_t height; // H dimension
+  int64_t width; // W dimension
+  std::string test_case_name = "placeholder";
+  std::string op_name = "qdq8ta_conv2d_input";
+};
+
+// Utility function to create a test case from a QDQ8TAConv2DConfig
+TestCase create_test_case_from_config(
+    const QDQ8TAConv2DConfig& config,
+    utils::StorageType storage_type,
+    vkapi::ScalarType input_dtype) {
+  TestCase test_case;
+
+  // Create a descriptive name for the test case
+  std::string storage_str =
+      (storage_type == utils::kTexture3D) ? "Texture3D" : "Buffer";
+  std::string dtype_str = (input_dtype == vkapi::kFloat) ? "Float" : "Half";
+
+  std::string test_name =
+      config.test_case_name + "_" + storage_str + "_" + dtype_str;
+  test_case.set_name(test_name);
+
+  // Set the operator name for the test case
+  std::string operator_name = "etvk." + config.op_name + ".default";
+  test_case.set_operator_name(operator_name);
+
+  // Input tensor (float) - [N, C, H, W]
+  std::vector<int64_t> input_size = {
+      config.batch_size, config.in_channels, config.height, config.width};
+  ValueSpec input_tensor(
+      input_size,
+      input_dtype,
+      storage_type,
+      utils::kChannelsPacked, // Use channels packed for conv2d tensors
+      DataGenType::RANDOM);
+
+  float scale_val = 0.007112;
+  ValueSpec scale(scale_val);
+
+  // Fixed zero point within the quantization range
+  int32_t zero_point_val = -2;
+  ValueSpec zero_point(zero_point_val);
+
+  // Output tensor (float) - same shape as input [N, C, H, W]
+  ValueSpec output_tensor(
+      input_size,
+      input_dtype,
+      storage_type,
+      utils::kChannelsPacked,
+      DataGenType::ZEROS);
+
+  // Add all specs to test case
+  test_case.add_input_spec(input_tensor);
+  test_case.add_input_spec(scale);
+  test_case.add_input_spec(zero_point);
+  test_case.add_output_spec(output_tensor);
+
+  test_case.set_abs_tolerance(scale_val + 1e-4);
+
+  return test_case;
+}
+
+// Generate easy test cases for qdq8ta_conv2d operation (for debugging)
+std::vector<TestCase> generate_qdq8ta_conv2d_easy_cases() {
+  std::vector<TestCase> test_cases;
+
+  // Single simple configuration for debugging
+  QDQ8TAConv2DConfig config = {
+      1, // batch_size
+      3, // in_channels
+      4, // height
+      4, // width
+      "simple", // test_case_name
+  };
+
+  // Test with texture storage only
+  std::vector<utils::StorageType> storage_types = {utils::kTexture3D};
+  std::vector<vkapi::ScalarType> float_types = {vkapi::kFloat};
+
+  // Generate test cases for each combination
+  for (const auto& storage_type : storage_types) {
+    for (const auto& input_dtype : float_types) {
+      test_cases.push_back(
+          create_test_case_from_config(config, storage_type, input_dtype));
+    }
+  }
+
+  return test_cases;
+}
+
+// Generate test cases for qdq8ta_conv2d operation
+std::vector<TestCase> generate_qdq8ta_conv2d_test_cases() {
+  std::vector<TestCase> test_cases;
+
+  std::vector<QDQ8TAConv2DConfig> configs = {
+      // Small test cases for correctness
+      {1, 3, 16, 16},
+      {1, 8, 32, 32},
+      {1, 16, 24, 24},
+      {1, 32, 12, 12},
+      {1, 1, 64, 64},
+      {1, 3, 64, 64},
+      {1, 4, 16, 16},
+
+      // Different tensor sizes
+      {1, 8, 20, 20},
+      {1, 16, 14, 14},
+      {1, 8, 28, 28},
+
+      // Odd tensor sizes
+      {1, 3, 15, 15},
+      {1, 13, 31, 31},
+      {1, 17, 23, 23},
+
+      // Performance test cases (larger tensors)
+      {1, 64, 128, 128},
+      {1, 32, 64, 64},
+      {1, 128, 56, 56},
+  };
+
+  // Test with different storage types
+  std::vector<utils::StorageType> storage_types = {utils::kTexture3D};
+
+  for (auto config : configs) {
+    std::string prefix =
+        (config.batch_size < kRefDimSizeLimit &&
+         config.in_channels < kRefDimSizeLimit &&
+         config.height < kRefDimSizeLimit && config.width < kRefDimSizeLimit)
+        ? "correctness_"
+        : "performance_";
+    std::string generated_test_case_name = prefix +
+        std::to_string(config.batch_size) + "_" +
+        std::to_string(config.in_channels) + "_" +
+        std::to_string(config.height) + "_" + std::to_string(config.width);
+
+    config.test_case_name = generated_test_case_name;
+
+    for (const auto& storage_type : storage_types) {
+      test_cases.push_back(
+          create_test_case_from_config(config, storage_type, vkapi::kFloat));
+    }
+  }
+
+  return test_cases;
+}
+
+// Reference implementation for qdq8ta_conv2d operation
+void qdq8ta_conv2d_reference_impl(TestCase& test_case) {
+  int32_t idx = 0;
+  const ValueSpec& input_spec = test_case.inputs()[idx++];
+  const ValueSpec& scale_spec = test_case.inputs()[idx++];
+  const ValueSpec& zero_point_spec = test_case.inputs()[idx++];
+
+  // Extract output specification
+  ValueSpec& output_spec = test_case.outputs()[0];
+
+  // Get tensor dimensions
+  auto input_sizes = input_spec.get_tensor_sizes(); // [N, C, H, W]
+  int64_t N = input_sizes[0];
+  int64_t C = input_sizes[1];
+  int64_t H = input_sizes[2];
+  int64_t W = input_sizes[3];
+
+  // Skip for large tensors since computation time will be extremely slow
+  if (N > kRefDimSizeLimit || C > kRefDimSizeLimit || H > kRefDimSizeLimit ||
+      W > kRefDimSizeLimit) {
+    throw std::invalid_argument(
+        "One or more dimensions (N, C, H, W) exceed the allowed limit for reference implementation.");
+  }
+
+  if (input_spec.dtype != vkapi::kFloat) {
+    throw std::invalid_argument("Unsupported dtype");
+  }
+
+  // Get raw data pointers
+  auto& input_data = input_spec.get_float_data();
+
+  // Extract the scale and zero point values (following the
+  // q8csw_conv2d.cpp pattern)
+  float scale = scale_spec.get_float_value();
+  int32_t zero_point = zero_point_spec.get_int_value();
+  int32_t quant_min = -128;
+  int32_t quant_max = 127;
+
+  // Prepare output data
+  auto& ref_data = output_spec.get_ref_float_data();
+  int64_t num_elements = N * C * H * W;
+  ref_data.resize(num_elements);
+
+  // Perform quantize-dequantize operation on each element
+  for (int64_t i = 0; i < num_elements; ++i) {
+    float input_val = input_data[i];
+
+    // Quantize: quantized = round(input / scale) + zero_point
+    float quantized_float = std::round(input_val / scale) + zero_point;
+
+    // Clamp to quantization range
+    quantized_float = std::max(quantized_float, static_cast<float>(quant_min));
+    quantized_float = std::min(quantized_float, static_cast<float>(quant_max));
+
+    int32_t quantized_int = static_cast<int32_t>(quantized_float);
+
+    // Dequantize: output = (quantized - zero_point) * scale
+    float dequantized = (quantized_int - zero_point) * scale;
+
+    ref_data[i] = dequantized;
+  }
+}
+
+int main(int argc, char* argv[]) {
+  set_debugging(false);
+  set_print_output(false);
+  set_print_latencies(false);
+  set_use_gpu_timestamps(true);
+
+  print_performance_header();
+  std::cout << "QDQ8TA Conv2D Operation Prototyping Framework" << std::endl;
+  print_separator();
+
+  ReferenceComputeFunc ref_fn = qdq8ta_conv2d_reference_impl;
+
+  auto results = execute_test_cases(
+      generate_qdq8ta_conv2d_test_cases, "QDQ8TAConv2D", 0, 1, ref_fn);
+
+  return 0;
+}
diff --git a/backends/vulkan/test/custom_ops/targets.bzl b/backends/vulkan/test/custom_ops/targets.bzl
index 3162857c2d3..1d1b1fe79bd 100644
--- a/backends/vulkan/test/custom_ops/targets.bzl
+++ b/backends/vulkan/test/custom_ops/targets.bzl
@@ -97,3 +97,4 @@ def define_common_targets(is_fbcode = False):
     define_custom_op_test_binary("q8csw_conv2d")
     define_custom_op_test_binary("choose_qparams_per_row")
     define_custom_op_test_binary("q4gsw_linear")
+    define_custom_op_test_binary("qdq8ta_conv2d_activations")