From 47a8a9f1e6358dbbcb139655e0c718be4da0d07e Mon Sep 17 00:00:00 2001
From: ssjia
Date: Sun, 7 Sep 2025 22:10:00 -0400
Subject: [PATCH] [ET-VK] Fast path for choose_qparams

Pull Request resolved: https://github.com/pytorch/executorch/pull/14019

The current implementations of `choose_qparams` are too slow to be practically
usable. As a temporary workaround to unblock LLM optimizations, this diff/PR
introduces a fast path for computing per-channel quantization parameters for
2D matrices, implemented by the new choose_qparams_per_row shader.

ghstack-source-id: 308092877
@exported-using-ghexport

Differential Revision: [D81800024](https://our.internmc.facebook.com/intern/diff/D81800024/)
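
In essence, the op reduces each row of a [num_channels, channel_size] matrix
to a single (scale, zero_point) pair. For orientation, here is a minimal CPU
sketch of the per-row math; the names are illustrative (not part of this
patch), the usual <algorithm>/<cmath>/<cstdint> includes are omitted, and the
authoritative versions are the GLSL shader and the CPU reference in the test
binary below:

    // Sketch only: asymmetric int8 qparams for one row of `width` floats.
    // Omits the small-scale cutoff handled by the real implementations.
    void choose_qparams_for_row(
        const float* row, int64_t width, int32_t qmin, int32_t qmax,
        float& scale, int32_t& zero_point) {
      float min_val = 0.0f; // start at 0 so the range always contains 0
      float max_val = 0.0f;
      for (int64_t i = 0; i < width; ++i) {
        min_val = std::min(min_val, row[i]);
        max_val = std::max(max_val, row[i]);
      }
      scale = (max_val - min_val) / static_cast<float>(qmax - qmin);
      if (scale == 0.0f) {
        scale = 0.1f; // avoid an infinite reciprocal downstream
      }
      const float zp = static_cast<float>(qmin) - min_val / scale;
      zero_point = static_cast<int32_t>(std::lround(
          std::min(std::max(zp, static_cast<float>(qmin)),
                   static_cast<float>(qmax))));
    }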
---
 backends/vulkan/op_registry.py                 |  15 +-
 .../vulkan/runtime/graph/ComputeGraph.cpp      |  51 ++-
 backends/vulkan/runtime/graph/ComputeGraph.h   |  30 +-
 .../ops/glsl/choose_qparams_per_row.glsl       | 184 +++++++++
 .../ops/glsl/choose_qparams_per_row.yaml       |  23 ++
 .../runtime/graph/ops/impl/ChooseQParams.cpp   | 189 +++++++++
 .../custom_ops/choose_qparams_per_row.cpp      | 363 ++++++++++++++++++
 backends/vulkan/test/custom_ops/targets.bzl    |   1 +
 8 files changed, 848 insertions(+), 8 deletions(-)
 create mode 100644 backends/vulkan/runtime/graph/ops/glsl/choose_qparams_per_row.glsl
 create mode 100644 backends/vulkan/runtime/graph/ops/glsl/choose_qparams_per_row.yaml
 create mode 100644 backends/vulkan/test/custom_ops/choose_qparams_per_row.cpp

diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py
index ade82bcde3b..8fbb41ed046 100644
--- a/backends/vulkan/op_registry.py
+++ b/backends/vulkan/op_registry.py
@@ -172,7 +172,6 @@ def register_affine_quantization_op():
 @update_features(
     [
-        exir_ops.edge.torchao.choose_qparams_affine.default,
         exir_ops.edge.quantized_decomposed.choose_qparams.tensor,
         exir_ops.edge.quantized_decomposed.choose_qparams_per_token_asymmetric.default,
     ]
@@ -184,6 +183,20 @@ def register_torchao_quantization_op():
     )
 
 
+@update_features(
+    exir_ops.edge.torchao.choose_qparams_affine.default,
+)
+def register_torchao_choose_qparams_affine():
+    return OpFeatures(
+        inputs_storage=utils.CONTIGUOUS_ANY,
+        outputs_storage=[
+            utils.CONTIGUOUS_BUFFER,  # scales
+            utils.CONTIGUOUS_BUFFER,  # zero_points
+        ],
+        supports_resize=True,
+    )
+
+
 @update_features(
     [
         exir_ops.edge.aten.add.Tensor,
diff --git a/backends/vulkan/runtime/graph/ComputeGraph.cpp b/backends/vulkan/runtime/graph/ComputeGraph.cpp
index f40a6b0f286..6609298b0d8 100644
--- a/backends/vulkan/runtime/graph/ComputeGraph.cpp
+++ b/backends/vulkan/runtime/graph/ComputeGraph.cpp
@@ -332,6 +332,16 @@ bool ComputeGraph::is_contiguous_buffer_tensor(const ValueRef idx) const {
   return is_contiguous(idx);
 }
 
+bool ComputeGraph::is_contiguous_texture_tensor(const ValueRef idx) const {
+  if (!val_is_tensor(idx)) {
+    return false;
+  }
+  if (is_buffer_storage(idx)) {
+    return false;
+  }
+  return has_standard_axis_map(idx) && packed_dim_of(idx) == 0;
+}
+
 bool ComputeGraph::is_standard_channels_packed_texture_tensor(
     const ValueRef idx) const {
   if (!val_is_tensor(idx)) {
@@ -343,15 +353,50 @@ bool ComputeGraph::is_standard_channels_packed_texture_tensor(
   return has_standard_axis_map(idx) && packed_dim_of(idx) == 2;
 }
 
-bool ComputeGraph::is_standard_width_packed_texture_tensor(
+bool ComputeGraph::is_2d_matrix(const ValueRef idx) const {
+  const std::vector<int64_t> sizes = sizes_of(idx);
+  const size_t ndim = sizes.size();
+  if (ndim < 2) {
+    return false;
+  }
+  if (ndim == 2) {
+    return true;
+  }
+
+  // Check that the outermost dims all have a size of 1
+  for (size_t d = 0; d < ndim - 2; d++) {
+    if (sizes[d] != 1) {
+      return false;
+    }
+  }
+
+  return true;
+}
+
+bool ComputeGraph::is_vectorizable_contiguous_2d_matrix(
     const ValueRef idx) const {
-  if (!val_is_tensor(idx)) {
+  if (!is_2d_matrix(idx)) {
     return false;
   }
   if (is_buffer_storage(idx)) {
+    return is_contiguous_buffer_tensor(idx) &&
+        size_at<int64_t>(-1, idx) % 4 == 0;
+  }
+  return is_contiguous_texture_tensor(idx);
+}
+
+bool ComputeGraph::is_vectorizable_width_packed_tensor(
+    const ValueRef idx) const {
+  // Not a tensor - return false
+  if (!val_is_tensor(idx)) {
     return false;
   }
-  return has_standard_axis_map(idx) && packed_dim_of(idx) == 0;
+  if (is_buffer_storage(idx)) {
+    return is_contiguous_buffer_tensor(idx) &&
+        size_at<int64_t>(-1, idx) % 4 == 0;
+  }
+
+  return is_standard_channels_packed_texture_tensor(idx);
 }
 
 ValueRef ComputeGraph::add_tensor(
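
A note on the predicates above, with a hedged illustration (the `graph` and
`input` names here are stand-ins, not part of the patch):

    // [64, 128] and [1, 1, 64, 128] both squeeze to a 2D matrix, so
    // is_2d_matrix(input) is true; [2, 64, 128] is rejected because its
    // leading dim is not 1. For buffer storage, the innermost dim must also
    // be a multiple of 4 (128 % 4 == 0) so the shader can load whole vec4
    // texels; width-packed textures satisfy this by construction via padding.
    const bool eligible = graph.is_vectorizable_contiguous_2d_matrix(input);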
diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h
index 4e9e2d36e1e..23b5517fd22 100644
--- a/backends/vulkan/runtime/graph/ComputeGraph.h
+++ b/backends/vulkan/runtime/graph/ComputeGraph.h
@@ -382,18 +382,40 @@ class ComputeGraph final {
    * 1. The value at `idx` is a tensor
    * 2. The tensor at `idx` has texture storage
    * 3. The texture backed tensor at `idx` has a standard axis mapping
-   * 4. The texture backed tensor at `idx` is channels packed
+   * 4. The texture backed tensor at `idx` is width packed
    */
-  bool is_standard_channels_packed_texture_tensor(const ValueRef idx) const;
+  bool is_contiguous_texture_tensor(const ValueRef idx) const;
 
   /*
    * Checks that the following is true:
    * 1. The value at `idx` is a tensor
    * 2. The tensor at `idx` has texture storage
    * 3. The texture backed tensor at `idx` has a standard axis mapping
-   * 4. The texture backed tensor at `idx` is width packed
+   * 4. The texture backed tensor at `idx` is channels packed
+   */
+  bool is_standard_channels_packed_texture_tensor(const ValueRef idx) const;
+
+  /*
+   * Checks that the value at `idx` is either a 2D tensor, or, if the tensor
+   * has more than 2 dims, that the outermost dims all have a size of 1, i.e.
+   * that it can be squeezed to a 2D tensor.
+   */
+  bool is_2d_matrix(const ValueRef idx) const;
+
+  /*
+   * Same as the above, but additionally requires that the tensor is either a
+   * contiguous buffer with its width divisible by 4, or a standard width
+   * packed texture.
+   */
+  bool is_vectorizable_contiguous_2d_matrix(const ValueRef idx) const;
+
+  /*
+   * Checks that the following is true:
+   * 1. The value at `idx` is a tensor
+   * 2. The tensor at `idx` is width packed
+   * 3. The tensor at `idx` has a standard axis mapping or is a contiguous
+   *    buffer
    */
-  bool is_standard_width_packed_texture_tensor(const ValueRef idx) const;
+  bool is_vectorizable_width_packed_tensor(const ValueRef idx) const;
 
   inline bool val_is_view_of(const ValueRef maybe_view, const ValueRef base)
       const {
diff --git a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_per_row.glsl b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_per_row.glsl
new file mode 100644
index 00000000000..653b0a251c0
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_per_row.glsl
@@ -0,0 +1,184 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#version 450 core
+
+#define PRECISION ${PRECISION}
+#define VEC4_T ${texel_load_type(DTYPE, STORAGE)}
+#define T ${texel_load_component_type(DTYPE, STORAGE)}
+
+#define NUM_OUTPUTS_PER_WG ${NUM_OUTPUTS_PER_WG}
+#define NUM_WORKERS_PER_OUTPUT ${NUM_WORKERS_PER_OUTPUT}
+
+// Maximum total threads in a work group
+#define MAX_THREADS 256
+
+${define_active_storage_type(STORAGE)}
+${define_required_extensions("int8")}
+
+#extension GL_EXT_control_flow_attributes : require
+
+layout(std430) buffer;
+
+#include "common.glslh"
+
+${layout_declare_tensor(B, "w", "t_scales", "float", "buffer")}
+${layout_declare_tensor(B, "w", "t_zps", "int", "buffer")}
+${layout_declare_tensor(B, "r", "t_input", DTYPE, STORAGE, is_scalar_array=False)}
+
+${layout_declare_ubo(B, "ivec4", "input_sizes")}
+
+layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
+
+layout(push_constant) uniform PushConstants {
+  int quant_min;
+  int quant_max;
+};
+
+// Shared memory for cooperative min/max finding
+shared T shared_min[NUM_OUTPUTS_PER_WG][NUM_WORKERS_PER_OUTPUT];
+shared T shared_max[NUM_OUTPUTS_PER_WG][NUM_WORKERS_PER_OUTPUT];
+
+const float SMALL_SCALE_THRESHOLD = 6.1e-5;
+
+void calculate_scale_and_zero_point(
+    float min_val,
+    float max_val,
+    int qmin,
+    int qmax,
+    out float scale,
+    out int8_t zero_point) {
+  // Extend the [min, max] interval to ensure it contains 0
+  min_val = min(min_val, 0.0);
+  max_val = max(max_val, 0.0);
+
+  // Calculate scale
+  scale = (max_val - min_val) / float(qmax - qmin);
+
+  // Handle special cases for scale
+  if (scale == 0.0 || isinf(1.0 / scale)) {
+    scale = 0.1;
+  }
+
+  // Cut off small scale
+  if (scale < SMALL_SCALE_THRESHOLD) {
+    float org_scale = scale;
+    scale = SMALL_SCALE_THRESHOLD;
+    // Adjust the min and max based on the new scale
+    if (min_val == 0.0) {
+      max_val = SMALL_SCALE_THRESHOLD * float(qmax - qmin);
+    } else if (max_val == 0.0) {
+      min_val = -SMALL_SCALE_THRESHOLD * float(qmax - qmin);
+    } else {
+      float amplifier = SMALL_SCALE_THRESHOLD / org_scale;
+      min_val *= amplifier;
+      max_val *= amplifier;
+    }
+  }
+
+  // Zero-point computation
+  float zero_point_from_min = float(qmin) - min_val / scale;
+  float zero_point_from_max = float(qmax) - max_val / scale;
+  float zero_point_from_min_error = abs(float(qmin)) - abs(min_val / scale);
+  float zero_point_from_max_error = abs(float(qmax)) - abs(max_val / scale);
+
+  float initial_zero_point =
+      zero_point_from_min_error < zero_point_from_max_error
+      ? zero_point_from_min
+      : zero_point_from_max;
+
+  // Nudge the zero point to be an integer
+  int nudged_zero_point;
+  if (initial_zero_point < float(qmin)) {
+    nudged_zero_point = qmin;
+  } else if (initial_zero_point > float(qmax)) {
+    nudged_zero_point = qmax;
+  } else {
+    nudged_zero_point = int(round(initial_zero_point));
+  }
+
+  zero_point = int8_t(nudged_zero_point);
+}
+
+#ifdef USING_BUFFER
+
+VEC4_T load_input_x4(const int x4, const int y, const int ntexels_x) {
+  return t_input[(y * ntexels_x) + x4];
+}
+
+#else // USING_TEXTURE
+
+VEC4_T load_input_x4(const int x4, const int y, const int ntexels_x) {
+  return texelFetch(t_input, ivec3(x4, y, 0), 0);
+}
+
+#endif // USING_BUFFER
+
+void main() {
+  const int worker_id = int(gl_LocalInvocationID.x);
+  const int output_id = int(gl_LocalInvocationID.y);
+
+  const int output_y = int(gl_GlobalInvocationID.y);
+
+  if (output_y >= input_sizes.y) {
+    return;
+  }
+
+  // The input is a 2D tensor (height x width), width-packed; each channel
+  // corresponds to a row in the tensor
+  const int X4 = div_4(input_sizes.x);
+
+  // Initialize thread-local min/max
+  float local_min = 1e30;
+  float local_max = -1e30;
+
+  // Each worker processes texels along its assigned row with a stride of
+  // NUM_WORKERS_PER_OUTPUT
+  for (int x4 = worker_id; x4 < X4; x4 += NUM_WORKERS_PER_OUTPUT) {
+    VEC4_T in_texel = load_input_x4(x4, output_y, X4);
+    for (int i = 0; i < 4; i++) {
+      local_min = min(local_min, in_texel[i]);
+      local_max = max(local_max, in_texel[i]);
+    }
+  }
+
+  // Store thread-local results in shared memory
+  shared_min[output_id][worker_id] = local_min;
+  shared_max[output_id][worker_id] = local_max;
+
+  memoryBarrierShared();
+  barrier();
+
+  // Tree reduction to compute the overall min/max
+  for (int i = NUM_WORKERS_PER_OUTPUT / 2; i > 0; i >>= 1) {
+    if (worker_id < i) {
+      shared_min[output_id][worker_id] = min(
+          shared_min[output_id][worker_id],
+          shared_min[output_id][worker_id + i]);
+      shared_max[output_id][worker_id] = max(
+          shared_max[output_id][worker_id],
+          shared_max[output_id][worker_id + i]);
+    }
+    memoryBarrierShared();
+    barrier();
+  }
+
+  // Only the first worker writes out the result
+  if (worker_id == 0) {
+    local_min = shared_min[output_id][0];
+    local_max = shared_max[output_id][0];
+
+    float scale;
+    int8_t zero_point;
+    calculate_scale_and_zero_point(
+        local_min, local_max, quant_min, quant_max, scale, zero_point);
+
+    t_scales[output_y] = scale;
+    t_zps[output_y] = zero_point;
+  }
+}
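
A quick sanity check of the dispatch arithmetic implied by the shader above
(illustrative numbers; the host-side sizing functions appear in
ChooseQParams.cpp below, and each work group always has
NUM_OUTPUTS_PER_WG * NUM_WORKERS_PER_OUTPUT = 64 threads):

    // o4w16 variant on a 64 x 128 input (fragment, illustrative only):
    //   global WG size = {1, height, 1} = {1, 64, 1}
    //   local  WG size = {workers_per_output, outputs_per_wg, 1} = {16, 4, 1}
    //   -> 64 / 4 = 16 work groups, each reducing 4 rows cooperatively
    //   each row has 128 / 4 = 32 texels, so each of the 16 workers per row
    //   loads 32 / 16 = 2 texels before the shared-memory tree reduction
    const uint32_t num_work_groups = (64u + 4u - 1u) / 4u; // 16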
diff --git a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_per_row.yaml b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_per_row.yaml
new file mode 100644
index 00000000000..3608f7193bf
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_per_row.yaml
@@ -0,0 +1,23 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+choose_qparams_per_row:
+  parameter_names_with_default_values:
+    DTYPE: float
+    STORAGE: texture3d
+    NUM_OUTPUTS_PER_WG: 1
+    NUM_WORKERS_PER_OUTPUT: 64
+  generate_variant_forall:
+    STORAGE:
+      - VALUE: texture3d
+      - VALUE: buffer
+    DTYPE:
+      - VALUE: float
+  shader_variants:
+    - NAME: choose_qparams_per_row_o1w64
+    - NAME: choose_qparams_per_row_o4w16
+      NUM_OUTPUTS_PER_WG: 4
+      NUM_WORKERS_PER_OUTPUT: 16
diff --git a/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp b/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp
index 2cf837fa89c..0d0be08bb38 100644
--- a/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp
@@ -14,6 +14,26 @@
 
 namespace vkcompute {
 
+void resize_choose_qparams_per_row(
+    ComputeGraph* graph,
+    const std::vector<ArgGroup>& args,
+    const std::vector<ValueRef>& extra_args) {
+  (void)extra_args;
+
+  ValueRef input_scales = args.at(0).refs.at(0);
+  ValueRef input_zeros = args.at(0).refs.at(1);
+  ValueRef input = args.at(1).refs.at(0);
+
+  std::vector<int64_t> new_sizes = graph->sizes_of(input_scales);
+  const size_t ndim = new_sizes.size();
+
+  const int64_t input_height = graph->size_at<int64_t>(-2, input);
+  new_sizes.at(ndim - 1) = input_height;
+
+  graph->virtual_resize(input_scales, new_sizes);
+  graph->virtual_resize(input_zeros, new_sizes);
+}
+
 utils::uvec3 choose_qparams_pick_global_wg_size(
     ComputeGraph* graph,
     const vkapi::ShaderInfo& shader,
@@ -158,6 +178,64 @@ utils::uvec3 choose_qparams_block_wise_pick_local_wg_size(
   }
 }
 
+vkapi::ShaderInfo pick_choose_qparams_per_row_shader(
+    ComputeGraph* graph,
+    const std::vector<ArgGroup>& args,
+    const std::vector<ValueRef>& resize_args) {
+  (void)resize_args;
+
+  const ValueRef input = args.at(1).refs.at(0);
+
+  const int64_t width = graph->size_at<int64_t>(-1, input);
+  // The height is the number of output channels
+  const int64_t height = graph->size_at<int64_t>(-2, input);
+
+  std::string kernel_name = "choose_qparams_per_row";
+  if (width > 256 || height == 1) {
+    kernel_name += "_o1w64";
+  } else {
+    kernel_name += "_o4w16";
+  }
+  add_storage_type_suffix(kernel_name, graph->storage_type_of(input));
+  add_dtype_suffix(kernel_name, graph->dtype_of(input));
+
+  return VK_KERNEL_FROM_STR(kernel_name);
+}
+
+utils::uvec3 pick_choose_qparams_per_row_global_wg_size(
+    ComputeGraph* graph,
+    const vkapi::ShaderInfo& shader,
+    const std::vector<ArgGroup>& args,
+    const std::vector<ValueRef>& resize_args) {
+  (void)shader;
+  (void)resize_args;
+
+  const ValueRef input = args.at(1).refs.at(0);
+  const uint32_t height = graph->size_at<uint32_t>(-2, input);
+  return {1u, height, 1u};
+}
+
+utils::uvec3 pick_choose_qparams_per_row_local_wg_size(
+    ComputeGraph* graph,
+    const vkapi::ShaderInfo& shader,
+    const utils::uvec3& global_workgroup_size,
+    const std::vector<ArgGroup>& args,
+    const std::vector<ValueRef>& resize_args) {
+  (void)graph;
+  (void)global_workgroup_size;
+  (void)args;
+  (void)resize_args;
+
+  uint32_t outputs_per_wg = 1u;
+  uint32_t workers_per_output = 64u;
+
+  if (shader.kernel_name.find("o4w16") != std::string::npos) {
+    outputs_per_wg = 4u;
+    workers_per_output = 16u;
+  }
+
+  return {workers_per_output, outputs_per_wg, 1u};
+}
+
 void add_choose_qparams_tensor_node(
     ComputeGraph& graph,
     const ValueRef& input,
@@ -312,6 +390,57 @@ void add_choose_qparams_per_token_asymmetric_node(
       nullptr));
 }
 
+void add_choose_qparams_per_row_node(
+    ComputeGraph& graph,
+    const ValueRef& input,
+    const ValueRef& quant_min,
+    const ValueRef& quant_max,
+    const ValueRef& input_scales,
+    const ValueRef& input_zps) {
+  // Default to the int8 range when quant_min/quant_max are not provided
+  int32_t quant_min_val = -128;
+  int32_t quant_max_val = 127;
+
+  if (!graph.val_is_none(quant_min)) {
+    quant_min_val = graph.extract_scalar<int32_t>(quant_min);
+  }
+  if (!graph.val_is_none(quant_max)) {
+    quant_max_val = graph.extract_scalar<int32_t>(quant_max);
+  }
+
+  vkapi::ParamsBindList param_ubos = {
+      graph.sizes_ubo(input),
+  };
+  std::vector<PushConstantDataInfo> push_constants = {
+      PushConstantDataInfo(&quant_min_val, sizeof(int32_t)),
+      PushConstantDataInfo(&quant_max_val, sizeof(int32_t)),
+  };
+
+  graph.execute_nodes().emplace_back(new DynamicDispatchNode(
+      graph,
+      pick_choose_qparams_per_row_shader,
+      pick_choose_qparams_per_row_global_wg_size,
+      pick_choose_qparams_per_row_local_wg_size,
+      // Inputs and Outputs
+      {{{input_scales, input_zps}, vkapi::kWrite}, {input, vkapi::kRead}},
+      // Shader param buffers
+      param_ubos,
+      // Push Constants
+      push_constants,
+      // Specialization Constants
+      {},
+      // Resize Args
+      {},
+      // Resizing Logic
+      resize_choose_qparams_per_row));
+}
+
 void add_choose_qparams_block_wise_node(
     ComputeGraph& graph,
     ValueRef input,
@@ -527,6 +656,32 @@ void choose_qparams_per_token_asymmetric_impl(
       graph, input, scale_out, zero_point_out);
 }
 
+bool can_use_choose_qparams_per_row(
+    ComputeGraph& graph,
+    const ValueRef input,
+    const ValueRef block_size,
+    const ValueRef input_zero_point) {
+  (void)input_zero_point;
+
+  if (!graph.is_vectorizable_contiguous_2d_matrix(input)) {
+    return false;
+  }
+
+  std::vector<int64_t> input_sizes = graph.sizes_of(input);
+  const IntListPtr block_size_vals = graph.get_int_list(block_size);
+  const size_t ndim = block_size_vals->size();
+
+  // Check for per-row quantization: the block must span the entire width dim
+  if (utils::val_at(-1, input_sizes) != utils::val_at(-1, *block_size_vals)) {
+    return false;
+  }
+
+  // and all other block dims must have a size of 1
+  for (size_t d = 0; d < ndim - 1; ++d) {
+    if (block_size_vals->at(d) != 1) {
+      return false;
+    }
+  }
+  return true;
+}
+
 void choose_qparams_affine_impl(
     ComputeGraph& graph,
     const std::vector<ValueRef>& args) {
@@ -556,6 +711,13 @@ void choose_qparams_affine_impl(
     zero_point_out = out_tuple->at(1);
   }
 
+  // Use the fast path if its eligibility conditions are met
+  if (can_use_choose_qparams_per_row(
+          graph, input, block_size, zero_point_out)) {
+    return add_choose_qparams_per_row_node(
+        graph, input, quant_min, quant_max, scale_out, zero_point_out);
+  }
+
   // Check tensor types
   VK_CHECK_COND(graph.val_is_tensor(input));
   VK_CHECK_COND(graph.val_is_tensor(scale_out));
@@ -611,6 +773,30 @@ void choose_qparams_affine_impl(
       zero_point_out);
 }
 
+void choose_qparams_per_row(
+    ComputeGraph& graph,
+    const std::vector<ValueRef>& args) {
+  int arg_idx = 0;
+  const ValueRef input = args[arg_idx++];
+  const ValueRef quant_min = args[arg_idx++];
+  const ValueRef quant_max = args[arg_idx++];
+  const ValueRef input_scales = args[arg_idx++];
+  const ValueRef input_zps = args[arg_idx++];
+
+  add_choose_qparams_per_row_node(
+      graph, input, quant_min, quant_max, input_scales, input_zps);
+}
+
 REGISTER_OPERATORS {
   VK_REGISTER_OP(
       quantized_decomposed.choose_qparams.tensor, choose_qparams_tensor_impl);
@@ -618,6 +804,9 @@ REGISTER_OPERATORS {
       quantized_decomposed.choose_qparams_per_token_asymmetric.default,
       choose_qparams_per_token_asymmetric_impl);
 
+  // Register the custom per-row quantization operator
+  VK_REGISTER_OP(etvk.choose_qparams_per_row.default, choose_qparams_per_row);
+
   // TorchAO affine choose_qparams operators
   VK_REGISTER_OP(
       torchao.choose_qparams_affine.default, choose_qparams_affine_impl);
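
The shader's push-constant block and the host-side PushConstantDataInfo
entries above must agree on layout. A quick sanity-check sketch (the struct
name is hypothetical and not part of the patch):

    #include <cstdint>

    // Hypothetical host-side mirror of the shader's block:
    //   layout(push_constant) uniform PushConstants {
    //     int quant_min; int quant_max;
    //   };
    // The node pushes two consecutive 32-bit ints, i.e. 8 bytes total.
    struct ChooseQParamsPushConstants {
      int32_t quant_min;
      int32_t quant_max;
    };
    static_assert(
        sizeof(ChooseQParamsPushConstants) == 2 * sizeof(int32_t),
        "push constant layout must match the shader");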
diff --git a/backends/vulkan/test/custom_ops/choose_qparams_per_row.cpp b/backends/vulkan/test/custom_ops/choose_qparams_per_row.cpp
new file mode 100644
index 00000000000..aa2b21feab8
--- /dev/null
+++ b/backends/vulkan/test/custom_ops/choose_qparams_per_row.cpp
@@ -0,0 +1,363 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <cmath>
+#include <cstdint>
+#include <iostream>
+#include <limits>
+#include <stdexcept>
+#include <string>
+#include "utils.h"
+
+#include <vector>
+
+using namespace executorch::vulkan::prototyping;
+using namespace vkcompute;
+
+static constexpr int64_t kRefDimSizeLimit = 2050;
+static constexpr float SMALL_SCALE_THRESHOLD = 6.1e-5f;
+
+// ChooseQParams configuration struct
+struct ChooseQParamsConfig {
+  int64_t num_channels; // Height dimension (number of channels)
+  int64_t channel_size; // Width dimension (size per channel)
+  int32_t quant_min = -128;
+  int32_t quant_max = 127;
+  std::string test_case_name = "placeholder";
+  std::string op_name = "choose_qparams_per_row";
+};
+
+// Utility function to create a test case from a ChooseQParamsConfig
+TestCase create_test_case_from_config(
+    const ChooseQParamsConfig& config,
+    utils::StorageType storage_type,
+    vkapi::ScalarType input_dtype) {
+  TestCase test_case;
+
+  // Create a descriptive name for the test case
+  std::string storage_str =
+      (storage_type == utils::kTexture3D) ? "Texture3D" : "Buffer";
+  std::string dtype_str = (input_dtype == vkapi::kFloat) ? "Float" : "Half";
+
+  std::string test_name =
+      config.test_case_name + "_" + storage_str + "_" + dtype_str;
+  test_case.set_name(test_name);
+
+  // Set the operator name for the test case
+  std::string operator_name = "etvk." + config.op_name + ".default";
+  test_case.set_operator_name(operator_name);
+
+  // Input tensor (float) - [num_channels, channel_size]
+  std::vector<int64_t> input_size = {config.num_channels, config.channel_size};
+  ValueSpec input_tensor(
+      input_size,
+      input_dtype,
+      storage_type,
+      utils::kWidthPacked,
+      DataGenType::RANDOM);
+
+  if (debugging()) {
+    print_valuespec_data(input_tensor, "input_tensor");
+  }
+
+  // Quantization parameters
+  ValueSpec quant_min(config.quant_min);
+  ValueSpec quant_max(config.quant_max);
+
+  // Output scale tensor (float) - [num_channels]
+  ValueSpec scale_out(
+      {config.num_channels},
+      vkapi::kFloat,
+      utils::kBuffer, // Always buffer as per requirement
+      utils::kWidthPacked,
+      DataGenType::ZEROS);
+
+  // Output zero_point tensor (int8) - [num_channels]
+  ValueSpec zero_point_out(
+      {config.num_channels},
+      vkapi::kChar, // int8 for quantized zero point
+      utils::kBuffer, // Always buffer as per requirement
+      utils::kWidthPacked,
+      DataGenType::ZEROS);
+
+  // Add all specs to the test case
+  test_case.add_input_spec(input_tensor);
+  test_case.add_input_spec(quant_min);
+  test_case.add_input_spec(quant_max);
+  test_case.add_output_spec(scale_out);
+  test_case.add_output_spec(zero_point_out);
+
+  return test_case;
+}
+
+// CPU reference implementation matching the behavior of
+// op_choose_qparams.cpp
+void calculate_scale_and_zero_point_reference(
+    float min_val,
+    float max_val,
+    int32_t qmin,
+    int32_t qmax,
+    float& scale,
+    int32_t& zero_point) {
+  // Extend the [min, max] interval to ensure that it contains 0
+  min_val = std::min(min_val, 0.0f);
+  max_val = std::max(max_val, 0.0f);
+
+  // Use double precision for the intermediate computation, but store the
+  // final number in single precision to reflect the actual number used
+  // during quantization.
+  double scale_double =
+      (static_cast<double>(max_val) - min_val) / (qmax - qmin);
+
+  // If the scale is 0, or so small that its reciprocal is infinity, we
+  // arbitrarily adjust the scale to 0.1. We want to avoid the reciprocal
+  // being infinity because some of the fbgemm code pre-computes the scale's
+  // reciprocal to do multiplication instead of division in the time-critical
+  // part of the code.
+  if (static_cast<float>(scale_double) == 0.0f ||
+      std::isinf(1.0f / static_cast<float>(scale_double))) {
+    scale_double = 0.1;
+  }
+
+  // Cut off small scale
+  if (scale_double < SMALL_SCALE_THRESHOLD) {
+    float org_scale = static_cast<float>(scale_double);
+    scale_double = SMALL_SCALE_THRESHOLD;
+    // Adjust the min and max based on the new scale
+    if (min_val == 0.0f) {
+      max_val = SMALL_SCALE_THRESHOLD * (qmax - qmin);
+    } else if (max_val == 0.0f) {
+      min_val = -SMALL_SCALE_THRESHOLD * (qmax - qmin);
+    } else {
+      float amplifier = SMALL_SCALE_THRESHOLD / org_scale;
+      min_val *= amplifier;
+      max_val *= amplifier;
+    }
+  }
+
+  // Zero-point computation.
+  // First the initial floating-point computation. The zero-point can be
+  // determined from solving an affine equation for any known pair
+  // (real value, corresponding quantized value).
+  // We know two such pairs: (rmin, qmin) and (rmax, qmax).
+  // The arithmetic error on the zero point computed from either pair
+  // will be roughly machine_epsilon * (sum of absolute values of terms)
+  // so we want to use the variant that adds the smaller terms.
+  double zero_point_from_min = qmin - min_val / scale_double;
+  double zero_point_from_max = qmax - max_val / scale_double;
+  double zero_point_from_min_error =
+      std::abs(qmin) - std::abs(min_val / scale_double);
+  double zero_point_from_max_error =
+      std::abs(qmax) - std::abs(max_val / scale_double);
+  double initial_zero_point =
+      zero_point_from_min_error < zero_point_from_max_error
+      ? zero_point_from_min
+      : zero_point_from_max;
+
+  // Now we need to nudge the zero point to be an integer
+  // (our zero points are integer, and this is motivated by the requirement
+  // to be able to represent the real value "0" exactly as a quantized value,
+  // which is required in multiple places, for example in Im2col with zero
+  // padding).
+  int32_t nudged_zero_point = 0;
+  if (initial_zero_point < qmin) {
+    nudged_zero_point = qmin;
+  } else if (initial_zero_point > qmax) {
+    nudged_zero_point = qmax;
+  } else {
+    nudged_zero_point = static_cast<int32_t>(
+        nearbyint(static_cast<float>(initial_zero_point)));
+  }
+
+  scale = static_cast<float>(scale_double);
+  zero_point = nudged_zero_point;
+}
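
As a worked numeric check of the reference above (values rounded; the snippet
is illustrative and not part of the test binary): with min = -1.0 and
max = 3.0 over the int8 range, scale = 4.0 / 255 ≈ 0.015686,
zero_point_from_min = -128 + 63.75 = -64.25 and
zero_point_from_max = 127 - 191.25 = -64.25, which nudges to -64.

    float scale = 0.0f;
    int32_t zero_point = 0;
    calculate_scale_and_zero_point_reference(
        /*min_val=*/-1.0f, /*max_val=*/3.0f,
        /*qmin=*/-128, /*qmax=*/127, scale, zero_point);
    // scale ≈ 0.015686f, zero_point == -64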
+
+// Generate easy test cases for the choose_qparams_per_channel operation (for
+// debugging)
+std::vector<TestCase> generate_choose_qparams_per_channel_easy_cases() {
+  std::vector<TestCase> test_cases;
+
+  // Single simple configuration for debugging
+  int num_channels = 4;
+  int channel_size = 8;
+
+  ChooseQParamsConfig config = {
+      num_channels, // num_channels
+      channel_size, // channel_size
+      -128, // quant_min
+      127, // quant_max
+      "simple", // test_case_name
+  };
+
+  // Test with both storage types
+  std::vector<utils::StorageType> storage_types = {
+      utils::kTexture3D, utils::kBuffer};
+  std::vector<vkapi::ScalarType> float_types = {vkapi::kFloat};
+
+  // Generate test cases for each combination
+  for (const auto& storage_type : storage_types) {
+    for (const auto& input_dtype : float_types) {
+      test_cases.push_back(
+          create_test_case_from_config(config, storage_type, input_dtype));
+    }
+  }
+
+  return test_cases;
+}
+
+// Generate test cases for the choose_qparams_per_channel operation
+std::vector<TestCase> generate_choose_qparams_per_channel_test_cases() {
+  std::vector<TestCase> test_cases;
+
+  std::vector<ChooseQParamsConfig> configs = {
+      {4, 16},
+      {8, 32},
+      {16, 64},
+      {32, 128},
+      {64, 256},
+      {128, 512},
+      {1, 512},
+      // Performance cases
+      {256, 1024},
+      {512, 2048},
+      {1, 2048},
+      {1, 8096},
+  };
+
+  // Test with different storage types
+  std::vector<utils::StorageType> storage_types = {
+      utils::kTexture3D, utils::kBuffer};
+
+  for (auto config : configs) {
+    std::string prefix = (config.num_channels < kRefDimSizeLimit &&
+                          config.channel_size < kRefDimSizeLimit)
+        ? "correctness_"
+        : "performance_";
"correctness_" + : "performance_"; + std::string generated_test_case_name = prefix + + std::to_string(config.num_channels) + "_" + + std::to_string(config.channel_size); + + config.test_case_name = generated_test_case_name; + + for (const auto& storage_type : storage_types) { + test_cases.push_back( + create_test_case_from_config(config, storage_type, vkapi::kFloat)); + } + } + + return test_cases; +} + +// Reference implementation for choose_qparams_per_channel +void choose_qparams_per_channel_reference_impl(TestCase& test_case) { + int32_t idx = 0; + const ValueSpec& input_spec = test_case.inputs()[idx++]; + const ValueSpec& quant_min_spec = test_case.inputs()[idx++]; + const ValueSpec& quant_max_spec = test_case.inputs()[idx++]; + const ValueSpec& eps_spec = test_case.inputs()[idx++]; + const ValueSpec& dtype_spec = test_case.inputs()[idx++]; + (void)eps_spec; // Unused in reference implementation + (void)dtype_spec; // Unused in reference implementation + + // Extract output specifications + ValueSpec& scale_out_spec = test_case.outputs()[0]; + ValueSpec& zero_point_out_spec = test_case.outputs()[1]; + + // Get tensor dimensions + auto input_sizes = + input_spec.get_tensor_sizes(); // [num_channels, channel_size] + int64_t num_channels = input_sizes[0]; + int64_t channel_size = input_sizes[1]; + + // Skip for large tensors since computation time will be extremely slow + if (num_channels > kRefDimSizeLimit || channel_size > kRefDimSizeLimit) { + throw std::invalid_argument( + "One or more dimensions (num_channels, channel_size) exceed the allowed limit for reference implementation."); + } + + if (input_spec.dtype != vkapi::kFloat) { + throw std::invalid_argument("Unsupported dtype"); + } + + // Get raw data pointers + auto& input_data = input_spec.get_float_data(); + int32_t quant_min = quant_min_spec.get_int_value(); + int32_t quant_max = quant_max_spec.get_int_value(); + + // Prepare output data + auto& scale_ref_data = scale_out_spec.get_ref_float_data(); + auto& zero_point_ref_data = zero_point_out_spec.get_ref_int8_data(); + scale_ref_data.resize(num_channels); + zero_point_ref_data.resize(num_channels); + + // Process each channel + for (int64_t channel = 0; channel < num_channels; ++channel) { + // Find min and max for this channel + float min_val = std::numeric_limits::max(); + float max_val = std::numeric_limits::lowest(); + + for (int64_t i = 0; i < channel_size; ++i) { + int64_t input_idx = channel * channel_size + i; + float val = input_data[input_idx]; + min_val = std::min(min_val, val); + max_val = std::max(max_val, val); + } + + // Calculate scale and zero point for this channel + float scale; + int32_t zero_point; + calculate_scale_and_zero_point_reference( + min_val, max_val, quant_min, quant_max, scale, zero_point); + + // Store results (cast zero_point to int8) + scale_ref_data[channel] = scale; + zero_point_ref_data[channel] = static_cast(zero_point); + } +} + +void reference_impl(TestCase& test_case) { + choose_qparams_per_channel_reference_impl(test_case); +} + +int64_t choose_qparams_per_channel_flop_calculator(const TestCase& test_case) { + // Get input dimensions + const auto& input_sizes = test_case.inputs()[0].get_tensor_sizes(); + int64_t num_channels = input_sizes[0]; + int64_t channel_size = input_sizes[1]; + + // Calculate FLOPs for choose_qparams_per_channel operation + // Each channel requires: + // - Min/max finding: approximately 2 * channel_size comparisons + // - Scale calculation: ~5 operations (division, min/max operations) + // - Zero point 
+  int64_t ops_per_channel = 2 * channel_size + 15; // Simplified estimate
+
+  int64_t flop = num_channels * ops_per_channel;
+
+  return flop;
+}
+
+int main(int argc, char* argv[]) {
+  set_debugging(false);
+  set_print_output(false);
+  set_print_latencies(false);
+  set_use_gpu_timestamps(true);
+
+  print_performance_header();
+  std::cout << "Choose QParams Per Channel Operation Prototyping Framework"
+            << std::endl;
+  print_separator();
+
+  ReferenceComputeFunc ref_fn = reference_impl;
+
+  auto results = execute_test_cases(
+      generate_choose_qparams_per_channel_test_cases,
+      choose_qparams_per_channel_flop_calculator,
+      "ChooseQParamsPerChannel",
+      0,
+      10,
+      ref_fn);
+
+  return 0;
+}
diff --git a/backends/vulkan/test/custom_ops/targets.bzl b/backends/vulkan/test/custom_ops/targets.bzl
index 4297565da80..5d99f90ec5a 100644
--- a/backends/vulkan/test/custom_ops/targets.bzl
+++ b/backends/vulkan/test/custom_ops/targets.bzl
@@ -95,3 +95,4 @@ def define_common_targets(is_fbcode = False):
     define_custom_op_test_binary("add")
     define_custom_op_test_binary("q8csw_linear")
    define_custom_op_test_binary("q8csw_conv2d")
+    define_custom_op_test_binary("choose_qparams_per_row")