diff --git a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.glsl deleted file mode 100644 index 7e21bcf0eba..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.glsl +++ /dev/null @@ -1,400 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#version 450 core - -#define PRECISION ${PRECISION} - -#define IN_T ${buffer_scalar_type(IN_DTYPE)} -#define SCALE_OUT_T ${buffer_scalar_type(SCALE_OUT_DTYPE)} -#define ZP_OUT_T ${buffer_scalar_type(ZP_OUT_DTYPE)} - -#define ${MODE} - -${define_active_storage_type("buffer")} -${define_required_extensions(IN_DTYPE)} -${define_required_extensions(SCALE_OUT_DTYPE)} -${define_required_extensions(ZP_OUT_DTYPE)} - -#extension GL_EXT_control_flow_attributes : require - -layout(std430) buffer; - -${layout_declare_tensor(B, "w", "t_scale", SCALE_OUT_DTYPE, "buffer")} -${layout_declare_tensor(B, "w", "t_zero_point", ZP_OUT_DTYPE, "buffer")} -${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "buffer")} - -$if MODE == "per_tensor": - layout(push_constant) uniform restrict Block { - int quant_min; - int quant_max; - float eps; - }; -$if MODE == "per_token": - layout(push_constant) uniform restrict Block { - int num_tokens; - int quant_min; - int quant_max; - }; -$if MODE == "block_wise": - layout(push_constant) uniform BlockPC { - ivec4 blockSize; // WHCN (>=1) - ivec4 numBlocks; // #blocks along W,H,C,N - ivec4 blockStride; // {1, #W, #W * #H, #W * #H * #C} - int mapping_type; // 0=ASYM, 1=SYM, 2=SYM_NO_CLIP - int quant_min; - int quant_max; - float eps; - }; - -${layout_declare_ubo(B, "ivec4", "t_in_sizes")} -${layout_declare_ubo(B, "ivec4", "t_in_strides")} -${layout_declare_ubo(B, "ivec4", "t_scale_sizes")} -${layout_declare_ubo(B, "ivec4", "t_scale_strides")} -${layout_declare_ubo(B, "ivec4", "t_zero_point_sizes")} -${layout_declare_ubo(B, "ivec4", "t_zero_point_strides")} - -#include "indexing_utils.h" -#include "choose_qparams.glslh" - -layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; - -#define NWORKERS 64 - -// Shared memory for reduction - must match local work group size -shared float shared_min[NWORKERS]; -shared float shared_max[NWORKERS]; - -/* - Quantization Parameter Computation Shader (Buffer Storage) - This shader computes quantization parameters (scale and zero_point) for converting - floating-point tensors to n-bit integer representations while preserving the - original data range as much as possible. The computed parameters enable efficient - quantization by mapping the continuous floating-point range to discrete integer values. - - Important Considerations: - (+) The input tensor is assumed to be WIDTH_PACKED (i.e., contiguous in the last dimension) - - Workgroup Configuration: - - choose_qparams_per_tensor - This mode computes a single set of quantization parameters for the entire tensor. - Uses parallel reduction across all threads to find global min/max values. - - (*) global_wg_size: {1, 1, 1} (single workgroup processes entire tensor) - (*) local_wg_size: {64, 1, 1} (matches NWORKERS for shared memory) - - - choose_qparams_per_token - This mode computes separate quantization parameters for each token in the tensor. - Each workgroup processes one token independently to find token-specific min/max. 
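- - For a concrete (illustrative) example: a WIDTH_PACKED input of shape [B, S, H] = [2, 4, 8] has num_tokens = B * S = 8, and token k spans the contiguous buffer range [k * H, (k + 1) * H), so token 3 covers indices [24, 32).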
- - (*) global_wg_size: {num_tokens, 1, 1} (one workgroup per token) - (*) local_wg_size: {1, 1, 1} (single thread per token) - - - choose_qparams_block_wise - This mode computes quantization parameters for each block of elements, allowing - fine-grained control over quantization granularity within the tensor. Each block - is processed independently to find its own min/max values and compute corresponding - scale and zero_point parameters. - - (*) global_wg_size: {nBlocks, 1u, 1u} (one workgroup per block) - (*) local_wg_size: {1, 1, 1} (single thread per block) - - Block-wise quantization supports multiple mapping types for scale/zero_point calculation: - - - mapping_type = 0 (ASYMMETRIC): - Uses asymmetric quantization where the full floating-point range [min, max] is - mapped to the quantized range [quant_min, quant_max]. This preserves the original - data distribution but may not center zero optimally. - - Calculation: - scale = (max - min) / (quant_max - quant_min) - zero_point = quant_min - round(min / scale) - - Example: For range [-3.5, 10.2] mapping to int4 [-8, 7]: - scale = (10.2 - (-3.5)) / (7 - (-8)) = 13.7 / 15 = 0.913 - zero_point = -8 - round(-3.5 / 0.913) = -8 - (-4) = -4 - - - mapping_type = 1 (SYMMETRIC): - Uses symmetric quantization where the range is centered around zero. The scale - is computed based on the maximum absolute value, ensuring zero is exactly - representable in the quantized domain. - - Calculation: - max_abs = max(abs(min), abs(max)) - scale = max_abs / ((quant_max - quant_min) / 2) - zero_point = (quant_max + quant_min + 1) / 2 // midpoint - - Example: For range [-3.5, 10.2] mapping to int4 [-8, 7]: - max_abs = max(3.5, 10.2) = 10.2 - scale = 10.2 / ((7 - (-8)) / 2) = 10.2 / 7.5 = 1.36 - zero_point = (-8 + 7 + 1) / 2 = 0 - - - mapping_type = 2 (SYMMETRIC_NO_CLIPPING_ERR): - A variant of symmetric quantization that minimizes clipping errors by computing - separate scales for positive and negative ranges, then using the maximum. This - reduces quantization error on the dominant range while ensuring no values are - clipped. - - Calculation: - smin = abs(min) / abs(quant_min) // scale for negative range - smax = max / quant_max // scale for positive range - scale = max(smin, smax) // use larger scale to avoid clipping - zero_point = (quant_max + quant_min + 1) / 2 // midpoint - - Example: For range [-3.5, 10.2] mapping to int4 [-8, 7]: - smin = 3.5 / 8 = 0.4375 - smax = 10.2 / 7 = 1.457 - scale = max(0.4375, 1.457) = 1.457 // use smax to avoid clipping positives - zero_point = (-8 + 7 + 1) / 2 = 0 - - Tree Reduction Algorithm for Min/Max Finding: - The shader uses a parallel tree reduction algorithm to efficiently find minimum and - maximum values across multiple threads. This approach reduces the number of memory - accesses and synchronization points compared to sequential scanning. - - Example with 8 threads processing values [10, 1, 8, 1, 0, 2, 3, 5]: - - Step 1 - Initial Population: - Each thread loads its assigned value into shared memory arrays. - shared_min: | 10 | 1 | 8 | 1 | 0 | 2 | 3 | 5 | - shared_max: | 10 | 1 | 8 | 1 | 0 | 2 | 3 | 5 | - Thread ID: | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | - - Step 2 - Stride 1 (Compare Adjacent Pairs): - Threads 0,2,4,6 compare with threads 1,3,5,7 respectively. 
- shared_min: | 1 | | 1 | | 0 | | 3 | | (min(10,1), min(8,1), min(0,2), min(3,5)) - shared_max: | 10 | | 8 | | 2 | | 5 | | (max(10,1), max(8,1), max(0,2), max(3,5)) - Active: | 0 | | 2 | | 4 | | 6 | | - - Step 3 - Stride 2 (Compare Pairs of Pairs): - Threads 0,4 compare with threads 2,6 respectively. - shared_min: | 1 | | | | 0 | | | | (min(1,1), min(0,3)) - shared_max: | 10 | | | | 5 | | | | (max(10,8), max(2,5)) - Active: | 0 | | | | 4 | | | | - - Step 4 - Stride 4 (Final Comparison): - Thread 0 compares with thread 4 to get final result. - shared_min: | 0 | | | | | | | | (min(1,0) = 0) - shared_max: | 10 | | | | | | | | (max(10,5) = 10) - Active: | 0 | | | | | | | | - - Final Result: global_min = 0, global_max = 10 (stored in shared_min[0], shared_max[0]) - - The tree reduction completes in log_2(N) steps where N is the number of threads, - providing O(log N) time complexity instead of O(N) for sequential reduction. - - Quantization Parameter Calculation: - Once min/max values are determined, the shader computes: - - scale = (max - min) / (quant_max - quant_min) - - zero_point = quantization offset to map floating-point zero to integer range - - Mode-Specific Behavior: - - Per-Tensor: Single workgroup with strided access across entire tensor - - Per-Token: Multiple workgroups, each processing one token independently - - Block-Wise: Each thread processes assigned blocks using nested loops over block dimensions -*/ - -#ifdef per_tensor - -void choose_qparams_per_tensor() { - uint global_id = gl_GlobalInvocationID.x; - uint local_id = gl_LocalInvocationID.x; - uint total_threads = gl_NumWorkGroups.x * gl_WorkGroupSize.x; - - uint total_elements = uint(t_in_sizes.x * t_in_sizes.y * t_in_sizes.z * t_in_sizes.w); - - // Each thread processes multiple elements with stride - float thread_min = 1.0/0.0; // +infinity - float thread_max = -1.0/0.0; // -infinity - bool found_valid = false; - - for (uint i = global_id; i < total_elements; i += total_threads) { - float val = t_in[i]; - if (!isnan(val) && !isinf(val)) { - if (!found_valid) { - thread_min = val; - thread_max = val; - found_valid = true; - } else { - thread_min = min(thread_min, val); - thread_max = max(thread_max, val); - } - } - } - - // Intra-group reduction using shared memory - shared_min[local_id] = thread_min; - shared_max[local_id] = thread_max; - barrier(); - - // Tree reduction within work group - for (uint stride = gl_WorkGroupSize.x / 2; stride > 0; stride >>= 1) { - if (local_id < stride) { - float other_min = shared_min[local_id + stride]; - float other_max = shared_max[local_id + stride]; - - if (!isinf(other_min) && (isinf(shared_min[local_id]) || other_min < shared_min[local_id])) { - shared_min[local_id] = other_min; - } - if (!isinf(other_max) && (isinf(shared_max[local_id]) || other_max > shared_max[local_id])) { - shared_max[local_id] = other_max; - } - } - barrier(); - } - - // Final result calculation (single workgroup only) - if (local_id == 0) { - float global_min = shared_min[0]; - float global_max = shared_max[0]; - - float scale_val; - int zero_point_val; - // Use default values: mapping_type=0 (ASYMMETRIC), eps from push constant - calc_scale_zp(global_min, global_max, quant_min, quant_max, 0, eps, scale_val, zero_point_val); - - t_scale[0] = SCALE_OUT_T(scale_val); - t_zero_point[0] = ZP_OUT_T(zero_point_val); - } -} - -#elif defined(per_token) - -void choose_qparams_per_token() { - uint total_elements = uint(t_in_sizes.x * t_in_sizes.y * t_in_sizes.z * t_in_sizes.w); - uint token_size = total_elements / 
uint(num_tokens); - - const uint TOTAL_TOKENS = uint(num_tokens); - - /* each invocation handles token-ids: id, id+STRIDE, id+2·STRIDE … */ - const uint STRIDE = gl_WorkGroupSize.x * gl_NumWorkGroups.x; - for (uint token_id = gl_GlobalInvocationID.x; token_id < TOTAL_TOKENS; token_id += STRIDE) { - // Calculate the start and end indices for this token - uint token_start = token_id * token_size; - uint token_end = token_start + token_size; - - // Each thread processes the entire token - float lo = 1.0/0.0; // +INF - float hi = -1.0/0.0; // -INF - bool found_valid = false; - - // Process all elements in this token - for (uint i = token_start; i < token_end; i++) { - float val = t_in[i]; - if (!isnan(val) && !isinf(val)) { - if (!found_valid) { - lo = hi = val; - found_valid = true; - } else { - lo = min(lo, val); - hi = max(hi, val); - } - } - } - - if (!found_valid) { - // If no valid values were found, use default values - lo = 0.0; - hi = 0.0; - } - - // Calculate scale and zero point directly - float scale_val; - int zero_point_val; - // Use default values: mapping_type=0 (ASYMMETRIC), eps=1e-5 - calc_scale_zp(lo, hi, quant_min, quant_max, 0, 1e-5, scale_val, zero_point_val); - - // Write results - t_scale[token_id] = SCALE_OUT_T(scale_val); - t_zero_point[token_id] = ZP_OUT_T(zero_point_val); - } -} - -#elif defined(block_wise) - -ivec4 block_id_to_coord(uint bid) { - ivec4 bc; - bc.w = int(bid) / blockStride.w; - - int r = int(bid) - bc.w * blockStride.w; - bc.z = r / blockStride.z; - - r -= bc.z * blockStride.z; - bc.y = r / blockStride.y; - - r -= bc.y * blockStride.y; - bc.x = r; - return bc; -} - -void choose_qparams_block_wise() { - const uint TOTAL_BLOCKS = uint(numBlocks.x * numBlocks.y * numBlocks.z * numBlocks.w); - - // each invocation handles block-ids: id, id+STRIDE, id+2·STRIDE - const uint STRIDE = gl_WorkGroupSize.x * gl_NumWorkGroups.x; - for (uint block_id = gl_GlobalInvocationID.x; block_id < TOTAL_BLOCKS; block_id += STRIDE) { - // block -> WHCN coordinate - ivec4 bc = block_id_to_coord(block_id); - ivec4 blockStart = bc * blockSize; // first element (inclusive) - ivec4 blockEnd = blockStart + blockSize; // last element (exclusive) - - // min / max scan over the block - float lo = 1.0/0.0; // +INF - float hi = -1.0/0.0; // -INF - bool found_valid = false; - - // Calculate actual block dimensions - ivec4 actualBlockSize = blockEnd - blockStart; - int blockElements = actualBlockSize.x * actualBlockSize.y * actualBlockSize.z * actualBlockSize.w; - - // Linear iteration over block elements - for (int elemIdx = 0; elemIdx < blockElements; ++elemIdx) { - // Convert linear index to 4D coordinates within block - int remaining = elemIdx; - int dn = remaining / (actualBlockSize.x * actualBlockSize.y * actualBlockSize.z); - remaining -= dn * (actualBlockSize.x * actualBlockSize.y * actualBlockSize.z); - int dc = remaining / (actualBlockSize.x * actualBlockSize.y); - remaining -= dc * (actualBlockSize.x * actualBlockSize.y); - int dh = remaining / actualBlockSize.x; - int dw = remaining - dh * actualBlockSize.x; - - ivec4 tidx = blockStart + ivec4(dw, dh, dc, dn); - uint idx = tidx_to_bufi(tidx, t_in_strides); - float v = t_in[idx]; - - if (!isnan(v) && !isinf(v)) { - if (!found_valid) { - lo = hi = v; - found_valid = true; - } else { - lo = min(lo, v); - hi = max(hi, v); - } - } - } - - // Handle the case where no valid values were found in the block - if (!found_valid) { - lo = 0.0; - hi = 0.0; - } - - float scale_val; - int zero_point_val; - calc_scale_zp(lo, hi, quant_min, 
quant_max, mapping_type, eps, scale_val, zero_point_val); - - t_scale[block_id] = SCALE_OUT_T(scale_val); - t_zero_point[block_id] = ZP_OUT_T(zero_point_val); - } -} - -#endif - -void main() { - choose_qparams_${MODE}(); -} diff --git a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.yaml deleted file mode 100644 index 8459b043baa..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.yaml +++ /dev/null @@ -1,22 +0,0 @@ -choose_qparams_buffer: - parameter_names_with_default_values: - IN_DTYPE: float - SCALE_OUT_DTYPE: float - ZP_OUT_DTYPE: int32 - MODE: per_tensor - generate_variant_forall: - IN_DTYPE: - - VALUE: float - SCALE_OUT_DTYPE: - - VALUE: float - ZP_OUT_DTYPE: - - VALUE: int32 - - VALUE: int8 - - VALUE: float - shader_variants: - - NAME: choose_qparams_tensor_buffer - MODE: per_tensor - - NAME: choose_qparams_per_token_asymmetric_buffer - MODE: per_token - - NAME: choose_qparams_block_wise_buffer - MODE: block_wise diff --git a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.glsl deleted file mode 100644 index a17a3ae41dd..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.glsl +++ /dev/null @@ -1,533 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#version 450 core - -#define PRECISION ${PRECISION} - -#define IN_T ${buffer_scalar_type(IN_DTYPE)} -#define FVEC4_T ${texel_load_type(IN_DTYPE, "texture3d")} -#define SCALE_OUT_T ${buffer_scalar_type(SCALE_OUT_DTYPE)} -#define ZP_OUT_T ${buffer_scalar_type(ZP_OUT_DTYPE)} - -#define ${MODE} - -${define_active_storage_type("texture3d")} -${define_required_extensions(IN_DTYPE)} -${define_required_extensions(SCALE_OUT_DTYPE)} -${define_required_extensions(ZP_OUT_DTYPE)} - -#extension GL_EXT_control_flow_attributes : require - -layout(std430) buffer; - -$if MODE != "block_wise": - ${layout_declare_tensor(B, "w", "t_scale", SCALE_OUT_DTYPE, "texture3d")} - ${layout_declare_tensor(B, "w", "t_zero_point", ZP_OUT_DTYPE, "texture3d")} -$else: - ${layout_declare_tensor(B, "w", "t_scale", SCALE_OUT_DTYPE, "buffer")} - ${layout_declare_tensor(B, "w", "t_zero_point", ZP_OUT_DTYPE, "buffer")} - -${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "texture3d")} - -$if MODE == "per_tensor": - layout(push_constant) uniform restrict Block { - int quant_min; - int quant_max; - float eps; - }; -$if MODE == "per_token": - layout(push_constant) uniform restrict Block { - int num_tokens; - int quant_min; - int quant_max; - }; -$if MODE == "block_wise": - layout(push_constant) uniform BlockPC { - ivec4 blockSize; // WHCN (>=1) - ivec4 numBlocks; // #blocks along W,H,C,N - ivec4 blockStride; // {1, #W, #W * #H, #W * #H * #C} - int mapping_type; // 0=ASYM, 1=SYM, 2=SYM_NO_CLIP - int quant_min; - int quant_max; - float eps; - }; - -${layout_declare_ubo(B, "ivec3", "t_in_limits")} -$if MODE != "block_wise": - ${layout_declare_ubo(B, "ivec3", "t_scale_limits")} - ${layout_declare_ubo(B, "ivec3", "t_zero_point_limits")} -$else: - ${layout_declare_ubo(B, "ivec4", "t_scale_sizes")} - ${layout_declare_ubo(B, "ivec4", "t_scale_strides")} - ${layout_declare_ubo(B, "ivec4", "t_zero_point_sizes")} - ${layout_declare_ubo(B, "ivec4", "t_zero_point_strides")} - - -#include 
"indexing_utils.h" -#include "choose_qparams.glslh" - -layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; - -#define NWORKERS 64 - -// Shared memory for reduction - must match local work group size -shared float shared_min[NWORKERS]; -shared float shared_max[NWORKERS]; - -/*/* - Quantization Parameter Computation Shader (Buffer Storage) - This shader computes quantization parameters (scale and zero_point) for converting - floating-point tensors to n-bit integer representations while preserving the - original data range as much as possible. The computed parameters enable efficient - quantization by mapping the continuous floating-point range to discrete integer values. - - Important Considerations: - (+) The input tensor is assumed to be WIDTH_PACKED (i.e., contiguous in the last dimension) - - Workgroup Configuration: - - choose_qparams_per_tensor - This mode computes a single set of quantization parameters for the entire tensor. - Uses parallel reduction across all threads to find global min/max values. - - (*) global_wg_size: default - (*) local_wg_size: default - - - choose_qparams_per_token - This mode computes separate quantization parameters for each token in the tensor. - Each workgroup processes one token independently to find token-specific min/max. - - (*) global_wg_size: default - (*) local_wg_size: {1, 1, 1} - - - choose_qparams_block_wise - This mode computes quantization parameters for each block of elements, allowing - fine-grained control over quantization granularity within the tensor. Each block - is processed independently to find its own min/max values and compute corresponding - scale and zero_point parameters. - - NOTE: This mode currently only supports buffer storage for the output. - - (*) global_wg_size: {nBlocks, 1u, 1u} (one workgroup per block) - (*) local_wg_size: {1, 1, 1} (single thread per block) - - Tree Reduction Algorithm for Min/Max Finding: - The shader uses a parallel tree reduction algorithm to efficiently find minimum and - maximum values across multiple threads. This approach reduces the number of memory - accesses and synchronization points compared to sequential scanning. - - Example with 8 threads processing values [10, 1, 8, 1, 0, 2, 3, 5]: - - Step 1 - Initial Population: - Each thread loads its assigned value into shared memory arrays. - shared_min: | 10 | 1 | 8 | 1 | 0 | 2 | 3 | 5 | - shared_max: | 10 | 1 | 8 | 1 | 0 | 2 | 3 | 5 | - Thread ID: | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | - - Step 2 - Stride 1 (Compare Adjacent Pairs): - Threads 0,2,4,6 compare with threads 1,3,5,7 respectively. - shared_min: | 1 | | 1 | | 0 | | 3 | | (min(10,1), min(8,1), min(0,2), min(3,5)) - shared_max: | 10 | | 8 | | 2 | | 5 | | (max(10,1), max(8,1), max(0,2), max(3,5)) - Active: | 0 | | 2 | | 4 | | 6 | | - - Step 3 - Stride 2 (Compare Pairs of Pairs): - Threads 0,4 compare with threads 2,6 respectively. - shared_min: | 1 | | | | 0 | | | | (min(1,1), min(0,3)) - shared_max: | 10 | | | | 5 | | | | (max(10,8), max(2,5)) - Active: | 0 | | | | 4 | | | | - - Step 4 - Stride 4 (Final Comparison): - Thread 0 compares with thread 4 to get final result. - shared_min: | 0 | | | | | | | | (min(1,0) = 0) - shared_max: | 10 | | | | | | | | (max(10,5) = 10) - Active: | 0 | | | | | | | | - - Final Result: global_min = 0, global_max = 10 (stored in shared_min[0], shared_max[0]) - - The tree reduction completes in log_2(N) steps where N is the number of threads, - providing O(log N) time complexity instead of O(N) for sequential reduction. 
- - Quantization Parameter Calculation: - Once min/max values are determined, the shader computes: - - scale = (max - min) / (quant_max - quant_min) - - zero_point = quantization offset to map floating-point zero to integer range - - Mode-Specific Behavior: - - Per-Tensor: Single workgroup with strided access across entire tensor - - Per-Token: Multiple workgroups, each processing one token independently -*/ - -#ifdef per_tensor - -void choose_qparams_per_tensor() { - uint global_id = gl_GlobalInvocationID.x; - uint local_id = gl_LocalInvocationID.x; - uint group_id = gl_WorkGroupID.x; - uint total_threads = gl_NumWorkGroups.x * gl_WorkGroupSize.x; - - uint total_texels = uint(t_in_limits.x * t_in_limits.y * t_in_limits.z); - - // Each thread processes multiple texels with stride - float thread_min = 1.0/0.0; // +infinity - float thread_max = -1.0/0.0; // -infinity - bool found_valid = false; - - // Process texels with stride across all threads - for (uint texel_idx = global_id; texel_idx < total_texels; texel_idx += total_threads) { - // Convert linear texel index to 3D coordinates - uint z = texel_idx / uint(t_in_limits.x * t_in_limits.y); - uint remainder = texel_idx % uint(t_in_limits.x * t_in_limits.y); - uint y = remainder / uint(t_in_limits.x); - uint x = remainder % uint(t_in_limits.x); - ivec3 texel_pos = ivec3(int(x), int(y), int(z)); - - FVEC4_T texel_data = load_texel(t_in, texel_pos); - - // For texture storage, we assume width-packed (packed_dim = 0) - // Calculate number of valid elements in this texel (handle padding) - int packed_dim = 0; // Width dimension is packed - ivec4 sizes = ivec4(t_in_limits, 1); // Convert limits to sizes format - ivec4 tensor_coord = to_tensor_idx(texel_pos, sizes, packed_dim); - - // Calculate total tensor elements to determine padding - int total_elements = t_in_limits.x * t_in_limits.y * t_in_limits.z * 4; - int linear_tensor_idx = tensor_coord.x + tensor_coord.y * sizes.x + - tensor_coord.z * sizes.x * sizes.y; - int remaining_elements = total_elements - (linear_tensor_idx); - int valid_elements = min(4, remaining_elements); - - // Find min/max within this texel, considering only valid elements - if (valid_elements >= 1 && !isnan(texel_data.x) && !isinf(texel_data.x)) { - if (!found_valid) { - thread_min = texel_data.x; - thread_max = texel_data.x; - found_valid = true; - } else { - thread_min = min(thread_min, texel_data.x); - thread_max = max(thread_max, texel_data.x); - } - } - - if (valid_elements >= 2 && !isnan(texel_data.y) && !isinf(texel_data.y)) { - if (!found_valid) { - thread_min = texel_data.y; - thread_max = texel_data.y; - found_valid = true; - } else { - thread_min = min(thread_min, texel_data.y); - thread_max = max(thread_max, texel_data.y); - } - } - - if (valid_elements >= 3 && !isnan(texel_data.z) && !isinf(texel_data.z)) { - if (!found_valid) { - thread_min = texel_data.z; - thread_max = texel_data.z; - found_valid = true; - } else { - thread_min = min(thread_min, texel_data.z); - thread_max = max(thread_max, texel_data.z); - } - } - - if (valid_elements >= 4 && !isnan(texel_data.w) && !isinf(texel_data.w)) { - if (!found_valid) { - thread_min = texel_data.w; - thread_max = texel_data.w; - found_valid = true; - } else { - thread_min = min(thread_min, texel_data.w); - thread_max = max(thread_max, texel_data.w); - } - } - } - - // Intra-workgroup reduction using shared memory - shared_min[local_id] = thread_min; - shared_max[local_id] = thread_max; - barrier(); - - // Tree reduction within work group - for (uint stride = 
gl_WorkGroupSize.x / 2; stride > 0; stride >>= 1) { - if (local_id < stride) { - float other_min = shared_min[local_id + stride]; - float other_max = shared_max[local_id + stride]; - - if (!isinf(other_min) && (isinf(shared_min[local_id]) || other_min < shared_min[local_id])) { - shared_min[local_id] = other_min; - } - if (!isinf(other_max) && (isinf(shared_max[local_id]) || other_max > shared_max[local_id])) { - shared_max[local_id] = other_max; - } - } - barrier(); - } - - // Final result calculation (single workgroup only for reliability) - if (local_id == 0 && group_id == 0) { - float global_min = shared_min[0]; - float global_max = shared_max[0]; - - float scale_val; - int zero_point_val; - calc_scale_zp(global_min, global_max, quant_min, quant_max, 0, eps, scale_val, zero_point_val); - - write_texel(t_scale, ivec3(0, 0, 0), vec4(SCALE_OUT_T(scale_val), 0.0, 0.0, 0.0)); - write_texel(t_zero_point, ivec3(0, 0, 0), ivec4(ZP_OUT_T(zero_point_val), 0, 0, 0)); - } -} - -#elif defined(per_token) - -void choose_qparams_per_token() { - // Each token is processed by multiple workgroups for parallel reduction - uint local_id = gl_LocalInvocationID.x; - uint group_id = gl_WorkGroupID.x; - uint total_workgroups = gl_NumWorkGroups.x; - - uint total_texels = uint(t_in_limits.x * t_in_limits.y * t_in_limits.z); - - // Calculate texels per token (assuming last dimension contains the token data) - // For per-token quantization, we assume tokens are along the last dimension - uint texels_per_token = total_texels / uint(num_tokens); - - // Calculate how many tokens each workgroup should process - uint tokens_per_workgroup = (uint(num_tokens) + total_workgroups - 1) / total_workgroups; - - // Calculate which tokens this workgroup is responsible for - uint start_token = group_id * tokens_per_workgroup; - uint end_token = min(start_token + tokens_per_workgroup, uint(num_tokens)); - - // Process each token assigned to this workgroup - for (uint token_id = start_token; token_id < end_token; token_id++) { - // Calculate the texel range for this token - uint token_start_texel = token_id * texels_per_token; - uint token_end_texel = token_start_texel + texels_per_token; - - // Each thread processes multiple texels within the token - float thread_min = 1.0/0.0; // +infinity - float thread_max = -1.0/0.0; // -infinity - bool found_valid = false; - - // Process texels within this token only - for (uint texel_idx = token_start_texel + local_id; texel_idx < token_end_texel; texel_idx += gl_WorkGroupSize.x) { - // Convert linear texel index to 3D coordinates - uint z = texel_idx / uint(t_in_limits.x * t_in_limits.y); - uint remainder = texel_idx % uint(t_in_limits.x * t_in_limits.y); - uint y = remainder / uint(t_in_limits.x); - uint x = remainder % uint(t_in_limits.x); - ivec3 texel_pos = ivec3(int(x), int(y), int(z)); - - FVEC4_T texel_data = load_texel(t_in, texel_pos); - - // For texture storage, we assume width-packed (packed_dim = 0) - // Calculate number of valid elements in this texel (handle padding) - int packed_dim = 0; // Width dimension is packed - ivec4 sizes = ivec4(t_in_limits, 1); // Convert limits to sizes format - ivec4 tensor_coord = to_tensor_idx(texel_pos, sizes, packed_dim); - - // Calculate total tensor elements to determine padding - int total_elements = t_in_limits.x * t_in_limits.y * t_in_limits.z * 4; - int linear_tensor_idx = tensor_coord.x + tensor_coord.y * sizes.x + - tensor_coord.z * sizes.x * sizes.y; - int remaining_elements = total_elements - (linear_tensor_idx); - int 
valid_elements = min(4, remaining_elements); - - // Find min/max within this texel, considering only valid elements - if (valid_elements >= 1 && !isnan(texel_data.x) && !isinf(texel_data.x)) { - if (!found_valid) { - thread_min = texel_data.x; - thread_max = texel_data.x; - found_valid = true; - } else { - thread_min = min(thread_min, texel_data.x); - thread_max = max(thread_max, texel_data.x); - } - } - - if (valid_elements >= 2 && !isnan(texel_data.y) && !isinf(texel_data.y)) { - if (!found_valid) { - thread_min = texel_data.y; - thread_max = texel_data.y; - found_valid = true; - } else { - thread_min = min(thread_min, texel_data.y); - thread_max = max(thread_max, texel_data.y); - } - } - - if (valid_elements >= 3 && !isnan(texel_data.z) && !isinf(texel_data.z)) { - if (!found_valid) { - thread_min = texel_data.z; - thread_max = texel_data.z; - found_valid = true; - } else { - thread_min = min(thread_min, texel_data.z); - thread_max = max(thread_max, texel_data.z); - } - } - - if (valid_elements >= 4 && !isnan(texel_data.w) && !isinf(texel_data.w)) { - if (!found_valid) { - thread_min = texel_data.w; - thread_max = texel_data.w; - found_valid = true; - } else { - thread_min = min(thread_min, texel_data.w); - thread_max = max(thread_max, texel_data.w); - } - } - } - - // Intra-workgroup reduction using shared memory - shared_min[local_id] = thread_min; - shared_max[local_id] = thread_max; - barrier(); - - // Tree reduction within work group - for (uint stride = gl_WorkGroupSize.x / 2; stride > 0; stride >>= 1) { - if (local_id < stride) { - float other_min = shared_min[local_id + stride]; - float other_max = shared_max[local_id + stride]; - - // Handle infinity values properly - if (!isinf(other_min) && (isinf(shared_min[local_id]) || other_min < shared_min[local_id])) { - shared_min[local_id] = other_min; - } - if (!isinf(other_max) && (isinf(shared_max[local_id]) || other_max > shared_max[local_id])) { - shared_max[local_id] = other_max; - } - } - barrier(); - } - - // Final calculation for this token - if (local_id == 0) { - float token_min = shared_min[0]; - float token_max = shared_max[0]; - - float scale_val; - int zero_point_val; - calc_scale_zp(token_min, token_max, quant_min, quant_max, 0, 1e-5, scale_val, zero_point_val); - - // Convert token_id to 3D coordinates for output texture - // Assuming output tensors have the same layout as input but with different dimensions - uint out_z = token_id / uint(t_scale_limits.x * t_scale_limits.y); - uint out_remainder = token_id % uint(t_scale_limits.x * t_scale_limits.y); - uint out_y = out_remainder / uint(t_scale_limits.x); - uint out_x = out_remainder % uint(t_scale_limits.x); - ivec3 out_pos = ivec3(int(out_x), int(out_y), int(out_z)); - - write_texel(t_scale, out_pos, vec4(SCALE_OUT_T(scale_val), 0.0, 0.0, 0.0)); - write_texel(t_zero_point, out_pos, ivec4(ZP_OUT_T(zero_point_val), 0, 0, 0)); - } - - // Synchronize before processing next token - barrier(); - } -} - -#elif defined(block_wise) - -ivec4 block_id_to_coord(uint bid) { - ivec4 bc; - bc.w = int(bid) / blockStride.w; - - int r = int(bid) - bc.w * blockStride.w; - bc.z = r / blockStride.z; - - r -= bc.z * blockStride.z; - bc.y = r / blockStride.y; - - r -= bc.y * blockStride.y; - bc.x = r; - return bc; -} - -void choose_qparams_block_wise() { - const uint T = uint(numBlocks.x * numBlocks.y * numBlocks.z * numBlocks.w); - const uint STRIDE = gl_WorkGroupSize.x * gl_NumWorkGroups.x; - - // tensor full size in WHCN order - const ivec4 tensorSz = blockSize * numBlocks; - - // 
Process blocks with stride for better parallelization - for (uint blkIdx = gl_GlobalInvocationID.x; blkIdx < T; blkIdx += STRIDE) { - // block index in WHCN - const ivec4 b4d = block_id_to_coord(blkIdx); - const ivec4 blockStart = b4d * blockSize; - const ivec4 blockEnd = blockStart + blockSize; - - // scan all elements inside the block - float vmin = 3.402823e38; // +FLT_MAX - float vmax = -3.402823e38; // -FLT_MAX - bool found_valid = false; - - // Calculate total elements in block for linear iteration - const int blockElements = blockSize.x * blockSize.y * blockSize.z * blockSize.w; - - // Linear iteration over block elements (more cache-friendly) - for (int elemIdx = 0; elemIdx < blockElements; ++elemIdx) { - // Convert linear index to 4D coordinates within block - int remaining = elemIdx; - int dn = remaining / (blockSize.x * blockSize.y * blockSize.z); - remaining -= dn * (blockSize.x * blockSize.y * blockSize.z); - int dc = remaining / (blockSize.x * blockSize.y); - remaining -= dc * (blockSize.x * blockSize.y); - int dh = remaining / blockSize.x; - int dw = remaining - dh * blockSize.x; - - ivec4 tidx = blockStart + ivec4(dw, dh, dc, dn); - - // skip padding when tensor size is not an exact multiple of block - if (any(greaterThanEqual(tidx, tensorSz))) { continue; } - - // tensor index -> (x,y,z,component) inside input texture - ivec4 posi = to_texture_elem_pos(tidx, tensorSz, 0); // 0 = W_DIM (width packed) - - // fetch texel and pick the element inside it - FVEC4_T texl = load_texel(t_in, posi.xyz); - float v; - if (posi.w == 0) v = texl.x; - else if (posi.w == 1) v = texl.y; - else if (posi.w == 2) v = texl.z; - else v = texl.w; - - if (!isnan(v) && !isinf(v)) { - if (!found_valid) { - vmin = vmax = v; - found_valid = true; - } else { - vmin = min(vmin, v); - vmax = max(vmax, v); - } - } - } - - // Handle case where no valid values were found - if (!found_valid) { - vmin = 0.0; - vmax = 0.0; - } - - // compute scale / zero‑point (same maths as buffer kernel) - float scale; - int zp; - calc_scale_zp(vmin, vmax, quant_min, quant_max, mapping_type, eps, scale, zp); - - // Write the scalar values directly to buffer using linear index - t_scale[blkIdx] = SCALE_OUT_T(scale); - t_zero_point[blkIdx] = ZP_OUT_T(zp); - } -} - -#endif - -void main() { - choose_qparams_${MODE}(); -} diff --git a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.yaml b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.yaml deleted file mode 100644 index 12228822d4b..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.yaml +++ /dev/null @@ -1,22 +0,0 @@ -choose_qparams_texture: - parameter_names_with_default_values: - IN_DTYPE: float - SCALE_OUT_DTYPE: float - ZP_OUT_DTYPE: int32 - MODE: per_tensor - generate_variant_forall: - IN_DTYPE: - - VALUE: float - SCALE_OUT_DTYPE: - - VALUE: float - ZP_OUT_DTYPE: - - VALUE: int32 - - VALUE: int8 - - VALUE: float - shader_variants: - - NAME: choose_qparams_tensor_texture3d - MODE: per_tensor - - NAME: choose_qparams_per_token_asymmetric_texture3d - MODE: per_token - - NAME: choose_qparams_block_wise_texture3d - MODE: block_wise diff --git a/backends/vulkan/runtime/graph/ops/glsl/dequantize.glslh b/backends/vulkan/runtime/graph/ops/glsl/dequantize.glslh deleted file mode 100644 index 7194bebda35..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/dequantize.glslh +++ /dev/null @@ -1,16 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. 
- * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#ifndef DEQUANTIZE_GLSLH -#define DEQUANTIZE_GLSLH - -OUT_T dequantize_val(IN_T qvalue, float scale_val, int zero_point_val) { - return OUT_T(float(int(qvalue) - zero_point_val) * scale_val); -} - -#endif // DEQUANTIZE_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.glsl deleted file mode 100644 index 57dc2d53fff..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.glsl +++ /dev/null @@ -1,263 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#version 450 core - -#define PRECISION ${PRECISION} - -#define IN_T ${buffer_scalar_type(IN_DTYPE)} -#define OUT_T ${buffer_scalar_type(OUT_DTYPE)} -#define SCALE_T ${buffer_scalar_type(SCALE_DTYPE)} -#define ZP_T ${buffer_scalar_type(ZP_DTYPE)} - -#define ${MODE} - -${define_active_storage_type("buffer")} -${define_required_extensions(IN_DTYPE)} -${define_required_extensions(OUT_DTYPE)} -${define_required_extensions(SCALE_DTYPE)} -${define_required_extensions(ZP_DTYPE)} - -layout(std430) buffer; - -#include "indexing_utils.h" - -${layout_declare_tensor(B, "w", "t_out", OUT_DTYPE, "buffer")} -${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "buffer")} - -$if MODE == "per_tensor": - ${layout_declare_tensor(B, "r", "t_scale", SCALE_DTYPE, "buffer")} - ${layout_declare_tensor(B, "r", "t_zero_point", ZP_DTYPE, "buffer")} - - layout(push_constant) uniform restrict Block { - int quant_min; - int quant_max; - }; -$if MODE == "per_token": - ${layout_declare_tensor(B, "r", "t_scale", SCALE_DTYPE, "buffer")} - ${layout_declare_tensor(B, "r", "t_zero_point", ZP_DTYPE, "buffer")} - - layout(push_constant) uniform restrict Block { - int num_tokens; - int quant_min; - int quant_max; - }; -$if MODE == "per_channel": - ${layout_declare_tensor(B, "r", "t_scale", SCALE_DTYPE, "buffer")} - ${layout_declare_tensor(B, "r", "t_zero_point", ZP_DTYPE, "buffer")} - - layout(push_constant) uniform restrict Block { - int axis; - int num_channels; - int quant_min; - int quant_max; - }; -$if MODE == "block_wise": - ${layout_declare_tensor(B, "r", "t_scale", SCALE_DTYPE, "buffer")} - ${layout_declare_tensor(B, "r", "t_zero_point", ZP_DTYPE, "buffer")} - - layout(push_constant) uniform restrict Block { - ivec4 blockSize; // bW, bH, bC, bN - ivec4 numBlocks; // tW/bW, tH/bH, tC/bC, tN/bN - ivec4 blockStride; // pre-computed linear strides for the block grid - int quant_min; - int quant_max; - }; - -${layout_declare_ubo(B, "int", "out_numel")} -${layout_declare_ubo(B, "ivec4", "t_in_sizes")} -${layout_declare_ubo(B, "ivec4", "t_in_strides")} -${layout_declare_ubo(B, "ivec4", "t_out_sizes")} -${layout_declare_ubo(B, "ivec4", "t_out_strides")} - -${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")} -${layout_declare_spec_const(C, "int", "in_layout", "DEFAULT_LAYOUT")} - -#include "dequantize.glslh" - -layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; - -const lowp ivec4 out_dim_order = unhash_dim_order(out_layout); -const lowp ivec4 in_dim_order = unhash_dim_order(in_layout); - -/* - Dequantization Shader (Buffer Storage) - This shader converts n-bit integer tensor values back to floating-point 
representations - using pre-computed quantization parameters (scale and zero_point). The dequantization - reconstructs the original floating-point values from their discrete integer representations - with minimal precision loss. - - Important Considerations: - (+) All input tensors are assumed to be WIDTH_PACKED (i.e., contiguous in the last dimension) - (+) The axis map layout is assumed to be a standard layout for scales and zero_points - (++) The scale and zero_point tensors must be implemented as buffers - - Workgroup Configuration: - - dequantize_per_tensor - This mode reverses the uniform quantization applied across the entire tensor by using the - single scale and zero_point values to convert quantized integer values back to their original - floating-point representation. - - (*) global_wg_size: default - (*) local_wg_size: default - - - dequantize_per_token - This mode reverses the quantization applied individually to each token (or element) in the - input by using separate scale and zero_point values for each token. For a tensor of shape - [B, S, H], it applies the inverse transformation token-wise across the B*S tokens, converting - quantized values back to their original floating-point representation for each group of H - elements independently. - - (*) global_wg_size: default - (*) local_wg_size: default - - - dequantize_per_channel - This mode reverses the quantization applied separately to each channel of the input tensor - by using distinct scale and zero_point values for each channel. For a tensor of shape - [B, C, H, W] with axis = 1, it applies the inverse transformation channel-wise across the C - channels, converting quantized values back to their original floating-point representation - independently for each channel. - - (*) global_wg_size: default - (*) local_wg_size: default - - - dequantize_block_wise - This mode reverses the block-wise quantization applied to groups of elements by using separate - scale and zero_point values for each block. Equivalent to dequantize_affine, it applies the - inverse affine transformation per block to convert quantized values back to their original - floating-point representation. For example, if the tensor shape is [6, 9, 4] and - blockSize = [3, 3, 2], the tensor is divided into 12 blocks, each containing 18 elements, - and dequantization is performed independently on each block. 
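- - Worked example (illustrative, taking that shape as WHCN): blockSize = [3, 3, 2, 1] and numBlocks = [2, 3, 2, 1] give blockStride = [1, 2, 6, 12]; the element at tidx = (4, 7, 3, 0) falls in bcoord = tidx / blockSize = (1, 2, 1, 0), so block_id = 1*1 + 2*2 + 1*6 + 0*12 = 11 and scale[11] / zero_point[11] are applied.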
- - (*) global_wg_size: default - (*) local_wg_size: default - - Dequantization Formula: - value = (qvalue - zero_point) * scale -*/ - -#ifdef per_tensor - -void dequantize_per_tensor() { - const int out_bufi = int(gl_GlobalInvocationID.x); - - if (out_bufi >= out_numel) { - return; - } - - const ivec4 out_tidx = bufi_to_tidx(out_bufi, t_out_strides, out_dim_order); - const int in_bufi = tidx_to_bufi(out_tidx, t_in_strides); - - IN_T qvalue = t_in[in_bufi]; - OUT_T value = dequantize_val(qvalue, float(t_scale[0]), int(t_zero_point[0])); - - t_out[out_bufi] = value; -} - -#elif defined(per_token) - -void dequantize_per_token() { - const int out_bufi = int(gl_GlobalInvocationID.x); - - if (out_bufi >= out_numel) { - return; - } - - const ivec4 out_tidx = bufi_to_tidx(out_bufi, t_out_strides, out_dim_order); - const int in_bufi = tidx_to_bufi(out_tidx, t_in_strides); - - IN_T qvalue = t_in[in_bufi]; - - int token_idx = 0; - - if (t_out_sizes.w > 1) { - // 4D tensor - token_idx = out_tidx.w * (t_out_sizes.z * t_out_sizes.y) + out_tidx.z * t_out_sizes.y + out_tidx.y; - } else if (t_out_sizes.z > 1) { - // 3D tensor - token_idx = out_tidx.z * t_out_sizes.y + out_tidx.y; - } else if (t_out_sizes.y > 1) { - // 2D tensor - token_idx = out_tidx.y; - } - // For 1D tensor, token_idx remains 0 - - token_idx = min(token_idx, num_tokens - 1); - - OUT_T value = dequantize_val(qvalue, float(t_scale[token_idx]), int(t_zero_point[token_idx])); - - t_out[out_bufi] = value; -} - -#elif defined(per_channel) - -void dequantize_per_channel() { - const int out_bufi = int(gl_GlobalInvocationID.x); - - if (out_bufi >= out_numel) { - return; - } - - const ivec4 out_tidx = bufi_to_tidx(out_bufi, t_out_strides, out_dim_order); - const int in_bufi = tidx_to_bufi(out_tidx, t_in_strides); - - IN_T qvalue = t_in[in_bufi]; - - // Calculate channel index based on the dequantization axis (already converted to WHCN) - // The axis parameter is now in WHCN coordinate system: - // axis 0 -> W dimension (tidx.x) - // axis 1 -> H dimension (tidx.y) - // axis 2 -> C dimension (tidx.z) - // axis 3 -> N dimension (tidx.w) - int channel_idx = 0; - - if (axis == 0) { - channel_idx = out_tidx.x; - } else if (axis == 1) { - channel_idx = out_tidx.y; - } else if (axis == 2) { - channel_idx = out_tidx.z; - } else if (axis == 3) { - channel_idx = out_tidx.w; - } - - channel_idx = min(channel_idx, num_channels - 1); - - OUT_T value = dequantize_val(qvalue, float(t_scale[channel_idx]), int(t_zero_point[channel_idx])); - - t_out[out_bufi] = value; -} - -#else // block_wise - -void dequantize_block_wise() { - const int out_bufi = int(gl_GlobalInvocationID.x); - - if (out_bufi >= out_numel) { - return; - } - - const ivec4 out_tidx = bufi_to_tidx(out_bufi, t_out_strides, out_dim_order); - const int in_bufi = tidx_to_bufi(out_tidx, t_in_strides); - - IN_T qvalue = t_in[in_bufi]; - - const ivec4 bcoord = out_tidx / blockSize; - - const int block_id = bcoord.x * blockStride.x + bcoord.y * blockStride.y + bcoord.z * blockStride.z + bcoord.w * blockStride.w; - - const OUT_T value = dequantize_val(qvalue, float(t_scale[block_id]), int(t_zero_point[block_id])); - - t_out[out_bufi] = value; -} - -#endif - -void main() { - dequantize_${MODE}(); -} diff --git a/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.yaml deleted file mode 100644 index a4375038a75..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.yaml +++ /dev/null @@ -1,31 +0,0 @@ -dequantize_buffer: 
- parameter_names_with_default_values: - IN_DTYPE: int32 - OUT_DTYPE: float - SCALE_DTYPE: float - ZP_DTYPE: int32 - MODE: per_tensor - generate_variant_forall: - IN_DTYPE: - - VALUE: uint8 - - VALUE: int8 - - VALUE: int32 - OUT_DTYPE: - - VALUE: half - - VALUE: float - - VALUE: double - SCALE_DTYPE: - - VALUE: float - ZP_DTYPE: - - VALUE: int8 - - VALUE: int32 - - VALUE: float - shader_variants: - - NAME: dequantize_per_tensor_buffer - MODE: per_tensor - - NAME: dequantize_per_token_buffer - MODE: per_token - - NAME: dequantize_per_channel_buffer - MODE: per_channel - - NAME: dequantize_block_wise_buffer - MODE: block_wise diff --git a/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.glsl deleted file mode 100644 index 19276cd8f7f..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.glsl +++ /dev/null @@ -1,347 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#version 450 core - -#define PRECISION ${PRECISION} - -#define IN_T ${buffer_scalar_type(IN_DTYPE)} -#define IVEC4_T ${texel_load_type(IN_DTYPE, "texture3d")} - -#define OUT_T ${buffer_scalar_type(OUT_DTYPE)} -#define FVEC4_T ${texel_load_type(OUT_DTYPE, "texture3d")} -#define SCALE_T ${buffer_scalar_type(SCALE_DTYPE)} -#define ZP_T ${buffer_scalar_type(ZP_DTYPE)} - -#define ${MODE} - -${define_active_storage_type("texture3d")} -${define_required_extensions(IN_DTYPE)} -${define_required_extensions(OUT_DTYPE)} -${define_required_extensions(SCALE_DTYPE)} -${define_required_extensions(ZP_DTYPE)} - -#extension GL_EXT_control_flow_attributes : require - -layout(std430) buffer; - -${layout_declare_tensor(B, "w", "t_out", OUT_DTYPE, "texture3d")} -${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "texture3d")} - -$if MODE == "per_tensor": - ${layout_declare_tensor(B, "r", "t_scale", SCALE_DTYPE, "buffer")} - ${layout_declare_tensor(B, "r", "t_zero_point", ZP_DTYPE, "buffer")} - - layout(push_constant) uniform restrict Block { - int quant_min; - int quant_max; - }; -$if MODE == "per_token": - ${layout_declare_tensor(B, "r", "t_scale", SCALE_DTYPE, "buffer")} - ${layout_declare_tensor(B, "r", "t_zero_point", ZP_DTYPE, "buffer")} - - layout(push_constant) uniform restrict Block { - int num_tokens; - int quant_min; - int quant_max; - }; -$if MODE == "per_channel": - ${layout_declare_tensor(B, "r", "t_scale", SCALE_DTYPE, "buffer")} - ${layout_declare_tensor(B, "r", "t_zero_point", ZP_DTYPE, "buffer")} - - layout(push_constant) uniform restrict Block { - int axis; - int num_channels; - int quant_min; - int quant_max; - }; -$if MODE == "block_wise": - ${layout_declare_tensor(B, "r", "t_scale", SCALE_DTYPE, "buffer")} - ${layout_declare_tensor(B, "r", "t_zero_point", ZP_DTYPE, "buffer")} - - layout(push_constant) uniform restrict Block { - ivec4 blockSize; // bW, bH, bC, bN - ivec4 numBlocks; // tW/bW, tH/bH, tC/bC, tN/bN - ivec4 blockStride; // pre-computed linear strides for the block grid - int quant_min; - int quant_max; - }; - -${layout_declare_ubo(B, "ivec3", "t_in_limits")} -${layout_declare_ubo(B, "ivec3", "t_out_limits")} - -#include "indexing_utils.h" -#include "dequantize.glslh" - -layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; - -/* - * DEQUANTIZATION SHADER (TEXTURE STORAGE) - * - * This shader converts n-bit integer tensor 
values back to floating-point representations - * using pre-computed quantization parameters (scale and zero_point). The dequantization - * reconstructs the original floating-point values from their discrete integer representations - * with minimal precision loss. - * - * ALGORITHM: - * 1. Load quantized integer texel (4 values) from 3D texture - * 2. Apply dequantization formula to each component: value = (qvalue - zero_point) * scale - * 3. Store reconstructed floating-point texel to output texture - * - * WORKGROUP CONFIGURATION: - * - Per-Tensor Mode: - * - Global WG Size: {W/4, H, C} for input size (W, H, C) with width-packing - * - Local WG Size: Default (typically {8, 8, 1} or based on global WG size) - * - Per-Token Mode: - * - Global WG Size: {W/4, H, C} for input size (W, H, C) with width-packing - * - Local WG Size: Default (typically {8, 8, 1} or based on global WG size) - * - * SUPPORTED CONFIGURATIONS: - * - Texture Storage: Uses 3D texture indexing with texel-based processing - * - Assumes width-packed layout (packed_dim = 0) for input/output textures - * - Handles texel padding for non-multiple-of-4 tensor dimensions - * - For per-token mode: scale/zero_point tensors must use buffer storage - * - Input/output textures: Must use standard axis mapping for per-token mode - * - * DEQUANTIZATION FORMULA VISUALIZATION: - * For integer range [quant_min, quant_max] mapped back to [min_val, max_val]: - * - * Integer Domain: Floating Point Domain: - * quant_min ──────────────► min_val - * │ │ - * │ scale = (max_val - min_val) / (quant_max - quant_min) - * │ zero_point = quant_min - round(min_val / scale) - * │ │ - * quant_max ──────────────► max_val - * - * Texel Dequantization Process: - * Input Texel: [-103, -128, -123, -96] (int8) - * Per-component dequantization with scale=0.1, zero_point=-128: - * Component 0: (-103 - (-128)) * 0.1 = 25 * 0.1 = 2.5 - * Component 1: (-128 - (-128)) * 0.1 = 0 * 0.1 = 0.0 - * Component 2: (-123 - (-128)) * 0.1 = 5 * 0.1 = 0.5 - * Component 3: (-96 - (-128)) * 0.1 = 32 * 0.1 = 3.2 - * Output Texel: [2.5, 0.0, 0.5, 3.2] (float4) - * - * PER-TENSOR DEQUANTIZATION: - * - Single scale and zero_point values for entire tensor - * - All texel components use same dequantization parameters - * - Scale/zero_point read from single-element buffers; quant_min/quant_max arrive as push constants - * - Each thread processes one texel (4 elements) independently - * - Formula: value[i] = (qvalue[i] - zero_point) * scale - * - * PER-TOKEN DEQUANTIZATION: - * - Separate scale and zero_point for each token - * - Token = all elements except last dimension (e.g., for [B,S,H]: B*S tokens of H elements) - * - Parameters stored in buffer arrays indexed by token_id - * - Each thread calculates token_id from its 3D texture position - * - Scale/zero_point buffers accessed directly (not as textures) - * - Formula: value[i] = (qvalue[i] - zero_point[token_id]) * scale[token_id] - * - * Token ID calculation for texel at position (x, y, z): - * - 3D tensor: token_id = z * texture_height + y - * - 2D tensor: token_id = y - * - 1D tensor: token_id = 0 - */ - -#ifdef per_tensor - -void dequantize_per_tensor() { - const ivec3 pos = ivec3(gl_GlobalInvocationID); - - // Skip if out of bounds - if (any(greaterThanEqual(pos, t_in_limits))) { - return; - } - - IVEC4_T intex = load_texel(t_in, pos); - FVEC4_T outtex; - - [[unroll]] for (int i = 0; i < 4; ++i) { - IN_T qvalue = IN_T(intex[i]); - OUT_T value = dequantize_val(qvalue, float(t_scale[0]), int(t_zero_point[0])); - - $if OUT_DTYPE == "double": - outtex[i] = float(value); -
$else: - outtex[i] = value; - } - write_texel(t_out, pos, outtex); -} - -#elif defined(per_token) - -void dequantize_per_token() { - const ivec3 pos = ivec3(gl_GlobalInvocationID); - - if (any(greaterThanEqual(pos, t_in_limits))) { - return; - } - - IVEC4_T intex = load_texel(t_in, pos); - - int token_idx = 0; - ivec3 dims = t_in_limits; - - if (dims.z > 1) { - // 3D tensor - token_idx = pos.z * dims.y + pos.y; - } else if (dims.y > 1) { - // 2D tensor - token_idx = pos.y; - } - // For 1D tensor, token_idx remains 0 - - token_idx = min(token_idx, num_tokens - 1); - - // Scale and zero_point are prepacked as buffers, so direct access - float scale_val = float(t_scale[token_idx]); - int zero_point_val = int(t_zero_point[token_idx]); - - FVEC4_T outtex; - [[unroll]] for (int i = 0; i < 4; ++i) { - IN_T qvalue = IN_T(intex[i]); - OUT_T value = dequantize_val(qvalue, scale_val, zero_point_val); - $if OUT_DTYPE == "double": - outtex[i] = float(value); - $else: - outtex[i] = value; - } - - write_texel(t_out, pos, outtex); -} - -#elif defined(per_channel) - -void dequantize_per_channel() { - const ivec3 pos = ivec3(gl_GlobalInvocationID); - - if (any(greaterThanEqual(pos, t_in_limits))) { - return; - } - - IVEC4_T intex = load_texel(t_in, pos); - FVEC4_T outtex; - - // Calculate channel index based on the dequantization axis (already converted to WHCN) - // The axis parameter is now in WHCN coordinate system: - // axis 0 -> W dimension (pos.x) - // axis 1 -> H dimension (pos.y) - // axis 2 -> C dimension (pos.z) - // axis 3 -> N dimension (batch folding in texture storage) - - if (axis == 0) { - // Width dimension - each texel component has different channel index - [[unroll]] for (int i = 0; i < 4; ++i) { - IN_T qvalue = IN_T(intex[i]); - int channel_idx = pos.x * 4 + i; - channel_idx = min(channel_idx, num_channels - 1); - - float scale_val = float(t_scale[channel_idx]); - int zero_point_val = int(t_zero_point[channel_idx]); - OUT_T value = dequantize_val(qvalue, scale_val, zero_point_val); - $if OUT_DTYPE == "double": - outtex[i] = float(value); - $else: - outtex[i] = value; - } - } else if (axis == 1) { - int channel_idx = pos.y; - channel_idx = min(channel_idx, num_channels - 1); - float scale_val = float(t_scale[channel_idx]); - int zero_point_val = int(t_zero_point[channel_idx]); - - [[unroll]] for (int i = 0; i < 4; ++i) { - IN_T qvalue = IN_T(intex[i]); - OUT_T value = dequantize_val(qvalue, scale_val, zero_point_val); - $if OUT_DTYPE == "double": - outtex[i] = float(value); - $else: - outtex[i] = value; - } - } else if (axis == 2) { - // Channel dimension - for 4D tensors, need to account for batch-channel folding - // The Z coordinate contains folded batch*channel information - // We need to extract the actual channel index from the folded dimension - int folded_idx = pos.z; - int channel_idx = folded_idx % num_channels; - - float scale_val = float(t_scale[channel_idx]); - int zero_point_val = int(t_zero_point[channel_idx]); - - [[unroll]] for (int i = 0; i < 4; ++i) { - IN_T qvalue = IN_T(intex[i]); - OUT_T value = dequantize_val(qvalue, scale_val, zero_point_val); - $if OUT_DTYPE == "double": - outtex[i] = float(value); - $else: - outtex[i] = value; - } - } else if (axis == 3) { - // Batch dimension - for 4D tensors, need to account for batch-channel folding - // The Z coordinate contains folded batch*channel information - // We need to extract the actual channel index from the folded dimension - int folded_idx = pos.z; - // In this case num_channels actually corresponds to the number 
of channels - // along the C dimension in N(C)HW layout - int channel_idx = folded_idx / num_channels; - - float scale_val = float(t_scale[channel_idx]); - int zero_point_val = int(t_zero_point[channel_idx]); - - [[unroll]] for (int i = 0; i < 4; ++i) { - IN_T qvalue = IN_T(intex[i]); - OUT_T value = dequantize_val(qvalue, scale_val, zero_point_val); - $if OUT_DTYPE == "double": - outtex[i] = float(value); - $else: - outtex[i] = value; - } - } - - write_texel(t_out, pos, outtex); -} - -#else // block_wise - -void dequantize_block_wise() { - const ivec3 pos = ivec3(gl_GlobalInvocationID); - - if (any(greaterThanEqual(pos, t_in_limits))) - return; - - IVEC4_T intex = load_texel(t_in, pos); - FVEC4_T outtex; - - ivec4 base_tidx = ivec4(pos.x * 4, pos.y, pos.z, 0); - int foldedZ = pos.z; - - int C_total = numBlocks.z * blockSize.z; - - [[unroll]] for (int i = 0; i < 4; ++i) { - ivec4 tidx = ivec4(base_tidx.x + i, base_tidx.y, (foldedZ % C_total), (foldedZ / C_total)); - - ivec4 bcoord = tidx / blockSize; - int block_id = bcoord.x * blockStride.x + bcoord.y * blockStride.y + bcoord.z * blockStride.z + bcoord.w * blockStride.w; - - IN_T qvalue = IN_T(intex[i]); - OUT_T value = dequantize_val(qvalue, float(t_scale[block_id]), int(t_zero_point[block_id])); - $if OUT_DTYPE == "double": - outtex[i] = float(value); - $else: - outtex[i] = value; - } - - write_texel(t_out, pos, outtex); -} - -#endif - -void main() { - dequantize_${MODE}(); -} diff --git a/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.yaml b/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.yaml deleted file mode 100644 index 7a58e9410d3..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.yaml +++ /dev/null @@ -1,31 +0,0 @@ -dequantize_texture: - parameter_names_with_default_values: - IN_DTYPE: int32 - OUT_DTYPE: float - SCALE_DTYPE: float - ZP_DTYPE: int32 - MODE: per_tensor - generate_variant_forall: - IN_DTYPE: - - VALUE: uint8 - - VALUE: int8 - - VALUE: int32 - OUT_DTYPE: - - VALUE: half - - VALUE: float - - VALUE: double - SCALE_DTYPE: - - VALUE: float - ZP_DTYPE: - - VALUE: int8 - - VALUE: int32 - - VALUE: float - shader_variants: - - NAME: dequantize_per_tensor_texture3d - MODE: per_tensor - - NAME: dequantize_per_token_texture3d - MODE: per_token - - NAME: dequantize_per_channel_texture3d - MODE: per_channel - - NAME: dequantize_block_wise_texture3d - MODE: block_wise diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.glsl deleted file mode 100644 index 7bf3a932c6c..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.glsl +++ /dev/null @@ -1,257 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree.
- */ - -#version 450 core - -#define PRECISION ${PRECISION} - -#define IN_T ${buffer_scalar_type(IN_DTYPE)} -#define OUT_T ${buffer_scalar_type(OUT_DTYPE)} -#define SCALE_T ${buffer_scalar_type(SCALE_DTYPE)} -#define ZP_T ${buffer_scalar_type(ZP_DTYPE)} - -#define ${MODE} - -${define_active_storage_type("buffer")} -${define_required_extensions(IN_DTYPE)} -${define_required_extensions(OUT_DTYPE)} -${define_required_extensions(SCALE_DTYPE)} -${define_required_extensions(ZP_DTYPE)} - -layout(std430) buffer; - -#include "indexing_utils.h" - -${layout_declare_tensor(B, "w", "t_out", OUT_DTYPE, "buffer")} -${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "buffer")} - -$if MODE == "per_tensor": - ${layout_declare_tensor(B, "r", "t_scale", SCALE_DTYPE, "buffer")} - ${layout_declare_tensor(B, "r", "t_zero_point", ZP_DTYPE, "buffer")} - - layout(push_constant) uniform restrict Block { - int quant_min; - int quant_max; - }; -$if MODE == "per_token": - ${layout_declare_tensor(B, "r", "t_scale", SCALE_DTYPE, "buffer")} - ${layout_declare_tensor(B, "r", "t_zero_point", ZP_DTYPE, "buffer")} - - layout(push_constant) uniform restrict Block { - int num_tokens; - int quant_min; - int quant_max; - }; -$if MODE == "per_channel": - ${layout_declare_tensor(B, "r", "t_scale", SCALE_DTYPE, "buffer")} - ${layout_declare_tensor(B, "r", "t_zero_point", ZP_DTYPE, "buffer")} - - layout(push_constant) uniform restrict Block { - int axis; - int num_channels; - int quant_min; - int quant_max; - }; -$if MODE == "block_wise": - ${layout_declare_tensor(B, "r", "t_scale", SCALE_DTYPE, "buffer")} - ${layout_declare_tensor(B, "r", "t_zero_point", ZP_DTYPE, "buffer")} - - layout(push_constant) uniform restrict Block { - ivec4 blockSize; // bW, bH, bC, bN - ivec4 numBlocks; // tW/bW, tH/bH, tC/bC, tN/bN - ivec4 blockStride; // pre-computed linear strides for the block grid - int quant_min; - int quant_max; - }; - -${layout_declare_ubo(B, "int", "out_numel")} -${layout_declare_ubo(B, "ivec4", "t_in_sizes")} -${layout_declare_ubo(B, "ivec4", "t_in_strides")} -${layout_declare_ubo(B, "ivec4", "t_out_sizes")} -${layout_declare_ubo(B, "ivec4", "t_out_strides")} - -${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")} -${layout_declare_spec_const(C, "int", "in_layout", "DEFAULT_LAYOUT")} - -#include "quantize.glslh" - -layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; - -const lowp ivec4 out_dim_order = unhash_dim_order(out_layout); -const lowp ivec4 in_dim_order = unhash_dim_order(in_layout); - -/* - Quantization Shader (Buffer Storage) - This shader converts floating-point tensor values to n-bit integer representations - using pre-computed quantization parameters (scale and zero_point). The quantization - maps floating-point values to a discrete integer range while preserving the original - data distribution as much as possible. - - Important Considerations: - (+) All input tensors are assumed to be WIDTH_PACKED (i.e., contiguous in the last dimension) - (+) The axis map layout is assumed to be a standard layout for scales and zero_points - (++) The scale and zero_point tensors must be implemented as buffers - - Workgroup Configuration: - - quantize_per_tensor - This mode applies uniform quantization across the entire tensor using a single scale - and zero_point value. 
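  A worked instance of the quantization formula stated at the end of this
  comment block; the scale, zero_point and clamp range below are illustrative
  numbers, not taken from any particular model:

    scale = 0.05, zero_point = 10, quant_min = -128, quant_max = 127
    value = 1.23  ->  round(1.23 / 0.05) + 10 = 25 + 10 = 35
    qvalue = clamp(35, -128, 127) = 35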
- - (*) global_wg_size: default - (*) local_wg_size: default - - - quantize_per_token - This mode applies quantization individually to each token (or element) in the input, - using separate scale and zero_point values for each token. For instance if we have - a tensor of shape [B, S, H] then we have B*S tokens (and s+zp pairs) of H elements each. - - (*) global_wg_size: default - (*) local_wg_size: default - - - quantize_per_channel - This mode applies quantization separately to each channel of the input tensor, using - distinct scale and zero_point values for each channel. For example, if the tensor shape - is [B, C, H, W] and axis = 1, quantization parameters are computed per channel C, allowing - each channel to be quantized independently. - - (*) global_wg_size: default - (*) local_wg_size: default - - - quantize_block_wise - This mode applies quantization in blocks or groups of elements, allowing different scale - and zero_point values for each block. It is equivalent to quantize_affine, where quantization - parameters are affine transformations applied per block. For example, if the tensor shape - is [6, 9, 4] and blockSize = [3, 3, 2], then we have 12 blocks each with 18 elements. - - (*) global_wg_size: default - (*) local_wg_size: default - - Quantization Formula: - qvalue = clamp(round(value / scale) + zero_point, quant_min, quant_max). -*/ - -#ifdef per_tensor - -void quantize_per_tensor() { - const int out_bufi = int(gl_GlobalInvocationID.x); - - if (out_bufi >= out_numel) { - return; - } - - const ivec4 out_tidx = bufi_to_tidx(out_bufi, t_out_strides, out_dim_order); - const int in_bufi = tidx_to_bufi(out_tidx, t_in_strides); - - IN_T value = t_in[in_bufi]; - OUT_T qvalue = quantize_val(value, float(t_scale[0]), int(t_zero_point[0])); - - t_out[out_bufi] = qvalue; -} - -#elif defined(per_token) - -void quantize_per_token() { - const int out_bufi = int(gl_GlobalInvocationID.x); - - if (out_bufi >= out_numel) { - return; - } - - const ivec4 out_tidx = bufi_to_tidx(out_bufi, t_out_strides, out_dim_order); - const int in_bufi = tidx_to_bufi(out_tidx, t_in_strides); - - IN_T value = t_in[in_bufi]; - - int token_idx = 0; - - if (t_out_sizes.w > 1) { - // 4D tensor - token_idx = out_tidx.w * (t_out_sizes.z * t_out_sizes.y) + out_tidx.z * t_out_sizes.y + out_tidx.y; - } else if (t_out_sizes.z > 1) { - // 3D tensor - token_idx = out_tidx.z * t_out_sizes.y + out_tidx.y; - } else if (t_out_sizes.y > 1) { - // 2D tensor - token_idx = out_tidx.y; - } - // For 1D tensor, token_idx remains 0 - - token_idx = min(token_idx, num_tokens - 1); - - OUT_T qvalue = quantize_val(value, float(t_scale[token_idx]), int(t_zero_point[token_idx])); - - t_out[out_bufi] = qvalue; -} - -#elif defined(per_channel) - -void quantize_per_channel() { - const int out_bufi = int(gl_GlobalInvocationID.x); - - if (out_bufi >= out_numel) { - return; - } - - const ivec4 out_tidx = bufi_to_tidx(out_bufi, t_out_strides, out_dim_order); - const int in_bufi = tidx_to_bufi(out_tidx, t_in_strides); - - IN_T value = t_in[in_bufi]; - - // Calculate channel index based on the quantization axis (already converted to WHCN) - // The axis parameter is now in WHCN coordinate system: - // axis 0 -> W dimension (tidx.x) - // axis 1 -> H dimension (tidx.y) - // axis 2 -> C dimension (tidx.z) - // axis 3 -> N dimension (tidx.w) - int channel_idx = 0; - - if (axis == 0) { - channel_idx = out_tidx.x; - } else if (axis == 1) { - channel_idx = out_tidx.y; - } else if (axis == 2) { - channel_idx = out_tidx.z; - } else if (axis == 3) { - 
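    // For buffer storage, bufi_to_tidx recovers the full WHCN index, so the
    // batch axis maps directly to out_tidx.w; no batch-channel folding is
    // needed here, unlike the texture variant of this shader.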
channel_idx = out_tidx.w; - } - - channel_idx = min(channel_idx, num_channels - 1); - - OUT_T qvalue = quantize_val(value, float(t_scale[channel_idx]), int(t_zero_point[channel_idx])); - - t_out[out_bufi] = qvalue; -} - -#else // block_wise - -void quantize_block_wise() { - const int out_bufi = int(gl_GlobalInvocationID.x); - - if (out_bufi >= out_numel) { - return; - } - - const ivec4 out_tidx = bufi_to_tidx(out_bufi, t_out_strides, out_dim_order); - const int in_bufi = tidx_to_bufi(out_tidx, t_in_strides); - - IN_T value = t_in[in_bufi]; - - const ivec4 bcoord = out_tidx / blockSize; - - const int block_id = bcoord.x * blockStride.x + bcoord.y * blockStride.y + bcoord.z * blockStride.z + bcoord.w * blockStride.w; - - const OUT_T qvalue = quantize_val(value, float(t_scale[block_id]), int(t_zero_point[block_id])); - - t_out[out_bufi] = qvalue; -} - -#endif - -void main() { - quantize_${MODE}(); -} diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.yaml deleted file mode 100644 index fb5853ecd20..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.yaml +++ /dev/null @@ -1,31 +0,0 @@ -quantize_buffer: - parameter_names_with_default_values: - IN_DTYPE: float - OUT_DTYPE: int32 - SCALE_DTYPE: float - ZP_DTYPE: int32 - MODE: per_tensor - generate_variant_forall: - IN_DTYPE: - - VALUE: half - - VALUE: float - - VALUE: double - OUT_DTYPE: - - VALUE: uint8 - - VALUE: int8 - - VALUE: int32 - SCALE_DTYPE: - - VALUE: float - ZP_DTYPE: - - VALUE: int8 - - VALUE: int32 - - VALUE: float - shader_variants: - - NAME: quantize_per_tensor_buffer - MODE: per_tensor - - NAME: quantize_per_token_buffer - MODE: per_token - - NAME: quantize_per_channel_buffer - MODE: per_channel - - NAME: quantize_block_wise_buffer - MODE: block_wise diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.glsl deleted file mode 100644 index 12e5769f50d..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.glsl +++ /dev/null @@ -1,312 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. 
- */ - -#version 450 core - -#define PRECISION ${PRECISION} - -#define IN_T ${buffer_scalar_type(IN_DTYPE)} -#define FVEC4_T ${texel_load_type(IN_DTYPE, "texture3d")} - -#define OUT_T ${buffer_scalar_type(OUT_DTYPE)} -#define IVEC4_T ${texel_load_type(OUT_DTYPE, "texture3d")} -#define SCALE_T ${buffer_scalar_type(SCALE_DTYPE)} -#define ZP_T ${buffer_scalar_type(ZP_DTYPE)} - -#define ${MODE} - -${define_active_storage_type("texture3d")} -${define_required_extensions(IN_DTYPE)} -${define_required_extensions(OUT_DTYPE)} -${define_required_extensions(SCALE_DTYPE)} -${define_required_extensions(ZP_DTYPE)} - -#extension GL_EXT_control_flow_attributes : require - -layout(std430) buffer; - -#include "indexing_utils.h" - -${layout_declare_tensor(B, "w", "t_out", OUT_DTYPE, "texture3d")} -${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "texture3d")} - -$if MODE == "per_tensor": - ${layout_declare_tensor(B, "r", "t_scale", SCALE_DTYPE, "buffer")} - ${layout_declare_tensor(B, "r", "t_zero_point", ZP_DTYPE, "buffer")} - - layout(push_constant) uniform restrict Block { - int quant_min; - int quant_max; - }; -$if MODE == "per_token": - ${layout_declare_tensor(B, "r", "t_scale", SCALE_DTYPE, "buffer")} - ${layout_declare_tensor(B, "r", "t_zero_point", ZP_DTYPE, "buffer")} - - layout(push_constant) uniform restrict Block { - int num_tokens; - int quant_min; - int quant_max; - }; -$if MODE == "per_channel": - ${layout_declare_tensor(B, "r", "t_scale", SCALE_DTYPE, "buffer")} - ${layout_declare_tensor(B, "r", "t_zero_point", ZP_DTYPE, "buffer")} - - layout(push_constant) uniform restrict Block { - int axis; - int num_channels; - int quant_min; - int quant_max; - }; -$if MODE == "block_wise": - ${layout_declare_tensor(B, "r", "t_scale", SCALE_DTYPE, "buffer")} - ${layout_declare_tensor(B, "r", "t_zero_point", ZP_DTYPE, "buffer")} - - layout(push_constant) uniform restrict BlockPC { - ivec4 blockSize; // WHCN - ivec4 numBlocks; // (#W,#H,#C,#N) - ivec4 blockStride; // {1, #W, #W * #H, #W * #H * #C} - int quant_min; - int quant_max; - }; - -${layout_declare_ubo(B, "ivec3", "t_in_limits")} -${layout_declare_ubo(B, "ivec3", "t_out_limits")} - -${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")} -${layout_declare_spec_const(C, "int", "in_layout", "DEFAULT_LAYOUT")} - -#include "quantize.glslh" - -layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; - -/* - Quantization Shader (Texture Storage) - This shader converts floating-point tensor values to n-bit integer representations - using pre-computed quantization parameters (scale and zero_point). The quantization - maps floating-point values to a discrete integer range while preserving the original - data distribution as much as possible. - - Important Considerations: - (+) All input tensors are assumed to be WIDTH_PACKED (i.e., contiguous in the last dimension) - (+) The axis map layout is assumed to be a standard layout for scales and zero_points - (++) The scale and zero_point tensors must be implemented as buffers - - Workgroup Configuration: - - quantize_per_tensor - This mode applies uniform quantization across the entire tensor using a single scale - and zero_point value. - - (*) global_wg_size: default - (*) local_wg_size: default - - - quantize_per_token - This mode applies quantization individually to each token (or element) in the input, - using separate scale and zero_point values for each token. For instance if we have - a tensor of shape [B, S, H] then we have B*S tokens (and s+zp pairs) of H elements each. 
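  For concreteness (shape and indices are illustrative): a [2, 4, 8] tensor
  carries 2 * 4 = 8 tokens, and the element at (b, s, h) uses
  scale[b * 4 + s] and zero_point[b * 4 + s], i.e. token_idx = b * S + s.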
- - (*) global_wg_size: default - (*) local_wg_size: default - - - quantize_per_channel - This mode applies quantization separately to each channel of the input tensor, using - distinct scale and zero_point values for each channel. For example, if the tensor shape - is [B, C, H, W] and axis = 1, quantization parameters are computed per channel C, allowing - each channel to be quantized independently. - - (*) global_wg_size: default - (*) local_wg_size: Default with special handling for batch dimension. When quantizing along - the batch axis, Z dimension is set to 1 to ensure correct workgroup dispatching. Otherwise, - uses standard workgroup size derived from global workgroup dimensions. - - - quantize_block_wise - This mode applies quantization in blocks or groups of elements, allowing different scale - and zero_point values for each block. It is equivalent to quantize_affine, where quantization - parameters are affine transformations applied per block. For example, if the tensor shape - is [6, 9, 4] and blockSize = [3, 3, 2], then we have 12 blocks each with 18 elements. - - (*) global_wg_size: default - (*) local_wg_size: Default with special handling for batch dimension. When quantizing along - the batch axis, Z dimension is set to 1 to ensure correct workgroup dispatching. Otherwise, - uses standard workgroup size derived from global workgroup dimensions. - - Quantization Formula: - qvalue = clamp(round(value / scale) + zero_point, quant_min, quant_max). -*/ - -#ifdef per_tensor - -void quantize_per_tensor() { - const ivec3 pos = ivec3(gl_GlobalInvocationID); - - if (any(greaterThanEqual(pos, t_in_limits))) { - return; - } - - FVEC4_T intex = load_texel(t_in, pos); - IVEC4_T outtex; - - [[unroll]] for (int i = 0; i < 4; ++i) { - IN_T value = IN_T(intex[i]); - OUT_T qvalue = quantize_val(value, float(t_scale[0]), int(t_zero_point[0])); - outtex[i] = qvalue; - } - write_texel(t_out, pos, outtex); -} - -#elif defined(per_token) - -void quantize_per_token() { - const ivec3 pos = ivec3(gl_GlobalInvocationID); - - if (any(greaterThanEqual(pos, t_in_limits))) { - return; - } - - FVEC4_T intex = load_texel(t_in, pos); - - int token_idx = 0; - ivec3 dims = t_in_limits; - - if (dims.z > 1) { - // 3D tensor - token_idx = pos.z * dims.y + pos.y; - } else if (dims.y > 1) { - // 2D tensor - token_idx = pos.y; - } - // For 1D tensor, token_idx remains 0 - - token_idx = min(token_idx, num_tokens - 1); - - // Scale and zero_point are prepacked as buffers, so direct access - float scale_val = float(t_scale[token_idx]); - int zero_point_val = int(t_zero_point[token_idx]); - - IVEC4_T outtex; - [[unroll]] for (int i = 0; i < 4; ++i) { - IN_T value = IN_T(intex[i]); - OUT_T qvalue = quantize_val(value, scale_val, zero_point_val); - outtex[i] = qvalue; - } - - write_texel(t_out, pos, outtex); -} - -#elif defined(per_channel) - -void quantize_per_channel() { - const ivec3 pos = ivec3(gl_GlobalInvocationID); - - if (any(greaterThanEqual(pos, t_in_limits))) { - return; - } - - FVEC4_T intex = load_texel(t_in, pos); - IVEC4_T outtex; - - // Calculate channel index based on the quantization axis (already converted to WHCN) - // The axis parameter is now in WHCN coordinate system: - // axis 0 -> W dimension (pos.x for texture, but width-packed so pos.x * 4 + component) - // axis 1 -> H dimension (pos.y) - // axis 2 -> C dimension (pos.z / C), but for 4D tensors this includes batch-channel folding - // axis 3 -> N dimension (pos.z / N), but for 4D tensors this includes batch-channel folding - - if (axis == 0) { - 
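    // Width packing places four consecutive W elements in one texel, so
    // component i of the texel at pos.x covers channel pos.x * 4 + i
    // (e.g. pos.x = 1, i = 2 -> channel 6); this is why the scale and
    // zero_point lookups below are per-component.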
// Width dimension - each texel component has different channel index - [[unroll]] for (int i = 0; i < 4; ++i) { - IN_T value = IN_T(intex[i]); - int channel_idx = pos.x * 4 + i; - channel_idx = min(channel_idx, num_channels - 1); - - float scale_val = float(t_scale[channel_idx]); - int zero_point_val = int(t_zero_point[channel_idx]); - OUT_T qvalue = quantize_val(value, scale_val, zero_point_val); - outtex[i] = qvalue; - } - } else if (axis == 1) { - // Height dimension - all texel components use same channel index - int channel_idx = pos.y; - channel_idx = min(channel_idx, num_channels - 1); - float scale_val = float(t_scale[channel_idx]); - int zero_point_val = int(t_zero_point[channel_idx]); - - [[unroll]] for (int i = 0; i < 4; ++i) { - IN_T value = IN_T(intex[i]); - OUT_T qvalue = quantize_val(value, scale_val, zero_point_val); - outtex[i] = qvalue; - } - } else if (axis == 2) { - // Channel dimension - for 4D tensors, need to account for batch-channel folding - // The Z coordinate contains folded batch*channel information - // We need to extract the actual channel index from the folded dimension - int folded_idx = pos.z; - int channel_idx = folded_idx % num_channels; - - float scale_val = float(t_scale[channel_idx]); - int zero_point_val = int(t_zero_point[channel_idx]); - - [[unroll]] for (int i = 0; i < 4; ++i) { - IN_T value = IN_T(intex[i]); - OUT_T qvalue = quantize_val(value, scale_val, zero_point_val); - outtex[i] = qvalue; - } - } else if (axis == 3) { - // Batch dimension - for 4D tensors, need to account for batch-channel folding - // The Z coordinate contains folded batch*channel information - // We need to extract the actual batch index from the folded dimension - int folded_idx = pos.z; - int batch_idx = folded_idx / num_channels; - - float scale_val = float(t_scale[batch_idx]); - int zero_point_val = int(t_zero_point[batch_idx]); - - [[unroll]] for (int i = 0; i < 4; ++i) { - IN_T value = IN_T(intex[i]); - OUT_T qvalue = quantize_val(value, scale_val, zero_point_val); - outtex[i] = qvalue; - } - } - - write_texel(t_out, pos, outtex); -} - -#else // block_wise - -void quantize_block_wise() { - const ivec3 pos = ivec3(gl_GlobalInvocationID); - - if (any(greaterThanEqual(pos, t_in_limits))) - return; - - FVEC4_T intex = load_texel(t_in, pos); - IVEC4_T outtex; - - ivec4 base_tidx = ivec4(pos.x * 4, pos.y, pos.z, 0); - int foldedZ = pos.z; - - int C_total = numBlocks.z * blockSize.z; - - [[unroll]] for (int i = 0; i < 4; ++i) { - ivec4 tidx = ivec4(base_tidx.x + i, base_tidx.y, (foldedZ % C_total), (foldedZ / C_total)); - - ivec4 bcoord = tidx / blockSize; - int block_id = bcoord.x * blockStride.x + bcoord.y * blockStride.y + bcoord.z * blockStride.z + bcoord.w * blockStride.w; - - IN_T value = IN_T(intex[i]); - OUT_T qvalue = quantize_val(value, float(t_scale[block_id]), int(t_zero_point[block_id])); - outtex[i] = qvalue; - } - - write_texel(t_out, pos, outtex); -} - -#endif - -void main() { - quantize_${MODE}(); -} diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.yaml b/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.yaml deleted file mode 100644 index 03d418ff2f7..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/quantize_texture.yaml +++ /dev/null @@ -1,31 +0,0 @@ -quantize_texture: - parameter_names_with_default_values: - IN_DTYPE: float - OUT_DTYPE: int32 - SCALE_DTYPE: float - ZP_DTYPE: int32 - MODE: per_tensor - generate_variant_forall: - IN_DTYPE: - - VALUE: half - - VALUE: float - - VALUE: double - OUT_DTYPE: - - VALUE: uint8 - 
- VALUE: int8 - - VALUE: int32 - SCALE_DTYPE: - - VALUE: float - ZP_DTYPE: - - VALUE: int8 - - VALUE: int32 - - VALUE: float - shader_variants: - - NAME: quantize_per_tensor_texture3d - MODE: per_tensor - - NAME: quantize_per_token_texture3d - MODE: per_token - - NAME: quantize_per_channel_texture3d - MODE: per_channel - - NAME: quantize_block_wise_texture3d - MODE: block_wise diff --git a/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp b/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp index a4a96ffdb88..a36660e0aca 100644 --- a/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp @@ -34,150 +34,6 @@ void resize_choose_qparams_per_row( graph->virtual_resize(input_zeros, new_sizes); } -utils::uvec3 choose_qparams_pick_global_wg_size( - ComputeGraph* graph, - const vkapi::ShaderInfo& shader, - const std::vector& args, - const std::vector& resize_args) { - (void)shader; - (void)resize_args; - - // For per-tensor quantization, we want a single workgroup that can handle - // all elements with proper reduction. The shader uses NWORKERS=64 threads. - const ValueRef input = args.at(1).refs.at(0); - - if (graph->is_buffer_storage(input)) { - // For buffer storage, use a single workgroup in X dimension - // The shader will handle strided access across all elements - return {1u, 1u, 1u}; - } else { - // For texture storage, use the default logic - return graph->create_global_wg_size(args.at(0).refs.at(0)); - } -} - -utils::uvec3 choose_qparams_pick_local_wg_size( - ComputeGraph* graph, - const vkapi::ShaderInfo& shader, - const utils::uvec3& global_workgroup_size, - const std::vector& args, - const std::vector& resize_args) { - (void)shader; - (void)resize_args; - - const ValueRef input = args.at(1).refs.at(0); - - if (graph->is_buffer_storage(input)) { - // For buffer storage, use 64 threads in X dimension to match NWORKERS - // This ensures the shared memory arrays are properly sized - return {64u, 1u, 1u}; - } else { - // For texture storage, use the default logic - return graph->create_local_wg_size(global_workgroup_size); - } -} - -utils::uvec3 choose_qparams_per_token_pick_global_wg_size( - ComputeGraph* graph, - const vkapi::ShaderInfo& shader, - const std::vector& args, - const std::vector& resize_args) { - (void)shader; - (void)resize_args; - - const ValueRef input = args.at(1).refs.at(0); - - if (graph->is_buffer_storage(input)) { - // For per-token quantization, we need one workgroup per token - // Calculate number of tokens (product of all dimensions except the last - // one) - const auto input_sizes = graph->sizes_of(input); - int64_t num_tokens = 1; - for (size_t i = 0; i < input_sizes.size() - 1; i++) { - num_tokens *= input_sizes[i]; - } - - return {static_cast(num_tokens), 1u, 1u}; - } else { - // For texture storage, use the default logic - return graph->create_global_wg_size(args.at(0).refs.at(0)); - } -} - -utils::uvec3 choose_qparams_per_token_pick_local_wg_size( - ComputeGraph* graph, - const vkapi::ShaderInfo& shader, - const utils::uvec3& global_workgroup_size, - const std::vector& args, - const std::vector& resize_args) { - (void)shader; - (void)resize_args; - - const ValueRef input = args.at(1).refs.at(0); - - if (graph->is_buffer_storage(input)) { - return {1u, 1u, 1u}; - } else { - // For texture storage, use the default logic - return graph->create_local_wg_size(global_workgroup_size); - } -} - -utils::uvec3 choose_qparams_block_wise_pick_global_wg_size( - ComputeGraph* g, - const 
vkapi::ShaderInfo&, - const std::vector& a, - const std::vector& r) { - const ValueRef input = a.at(2).refs.at(0); - const auto blkRef = r.at(0); - const auto inSz = g->sizes_of(input); - const auto blkList = g->get_int_list(blkRef); - - // Use same code as in add_choose_qparams_block_wise_node - utils::ivec4 block_size_vec = utils::make_whcn_ivec4(*blkList); - utils::ivec4 tensor_size_whcn = utils::make_whcn_ivec4(inSz); - - // Calculate numBlocks: ceil(tensorSize / blockSize) (both in WHCN order) - utils::ivec4 nBlk = { - (tensor_size_whcn[0] + block_size_vec[0] - 1) / block_size_vec[0], - (tensor_size_whcn[1] + block_size_vec[1] - 1) / block_size_vec[1], - (tensor_size_whcn[2] + block_size_vec[2] - 1) / block_size_vec[2], - (tensor_size_whcn[3] + block_size_vec[3] - 1) / block_size_vec[3]}; - - uint32_t nBlocks = nBlk[0] * nBlk[1] * nBlk[2] * nBlk[3]; - - // For texture storage, use more threads to better utilize GPU parallelism - // Each thread can process multiple blocks with stride - if (g->is_buffer_storage(input)) { - return {nBlocks, 1u, 1u}; - } else { - // For texture storage, use more workgroups to better utilize GPU - // Aim for ~64-256 threads per workgroup for good occupancy - uint32_t preferred_threads_per_wg = 64; - uint32_t num_workgroups = - (nBlocks + preferred_threads_per_wg - 1) / preferred_threads_per_wg; - num_workgroups = std::max(1u, std::min(num_workgroups, nBlocks)); - return {num_workgroups * preferred_threads_per_wg, 1u, 1u}; - } -} - -utils::uvec3 choose_qparams_block_wise_pick_local_wg_size( - ComputeGraph* g, - const vkapi::ShaderInfo&, - const utils::uvec3& global_wg_size, - const std::vector& a, - const std::vector&) { - const ValueRef input = a.at(2).refs.at(0); - - if (g->is_buffer_storage(input)) { - return {1u, 1u, 1u}; - } else { - // For texture storage, use 64 threads per workgroup for better occupancy - uint32_t local_size = std::min(64u, global_wg_size[0]); - return {local_size, 1u, 1u}; - } -} - vkapi::ShaderInfo pick_choose_qparams_per_row_shader( ComputeGraph* graph, const std::vector& args, @@ -222,160 +78,6 @@ utils::uvec3 pick_choose_qparams_per_row_local_wg_size( return {workers_per_output, outputs_per_wg, 1u}; } -void add_choose_qparams_tensor_node( - ComputeGraph& graph, - const ValueRef& input, - const ValueRef& quant_min, - const ValueRef& quant_max, - const ValueRef& eps, - const ValueRef& scale_out, - const ValueRef& zero_point_out) { - std::string kernel_name("choose_qparams_tensor"); - add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); - add_dtype_suffix(kernel_name, graph.dtype_of(input)); - add_dtype_suffix(kernel_name, graph.dtype_of(scale_out)); - add_dtype_suffix(kernel_name, graph.dtype_of(zero_point_out)); - - // Handle optional quant_min and quant_max parameters independently - auto bounds = get_dtype_bounds(graph.dtype_of(zero_point_out)); - - int quant_min_val, quant_max_val; - - // Handle quant_min - if (graph.val_is_none(quant_min)) { - quant_min_val = bounds.first; - } else { - VK_CHECK_COND( - graph.val_is_int(quant_min), - "quant_min must be an integer, got type: ", - graph.get_val_type(quant_min)); - quant_min_val = static_cast(graph.get_int(quant_min)); - } - - // Handle quant_max - if (graph.val_is_none(quant_max)) { - quant_max_val = bounds.second; - } else { - VK_CHECK_COND( - graph.val_is_int(quant_max), - "quant_max must be an integer, got type: ", - graph.get_val_type(quant_max)); - quant_max_val = static_cast(graph.get_int(quant_max)); - } - float eps_val = 
static_cast(graph.get_double(eps)); - - vkapi::ParamsBindList param_ubos; - std::vector push_constants; - - if (graph.is_buffer_storage(input)) { - param_ubos = { - graph.sizes_ubo(input), - graph.strides_ubo(input), - graph.sizes_ubo(scale_out), - graph.strides_ubo(scale_out), - graph.sizes_ubo(zero_point_out), - graph.strides_ubo(zero_point_out)}; - } else { - param_ubos = { - graph.logical_limits_ubo(input), - graph.logical_limits_ubo(scale_out), - graph.logical_limits_ubo(zero_point_out)}; - } - - push_constants = { - PushConstantDataInfo(&quant_min_val, sizeof(int)), - PushConstantDataInfo(&quant_max_val, sizeof(int)), - PushConstantDataInfo(&eps_val, sizeof(float)), - }; - - graph.execute_nodes().emplace_back(new DynamicDispatchNode( - graph, - VK_KERNEL_FROM_STR(kernel_name), - choose_qparams_pick_global_wg_size, - choose_qparams_pick_local_wg_size, - // Inputs and Outputs - {{scale_out, vkapi::kWrite}, - {zero_point_out, vkapi::kWrite}, - {input, vkapi::kRead}}, - // Shader param buffers - param_ubos, - // Push Constants - push_constants, - // Specialization Constants - {}, - // Resize Args - {}, - // Resizing Logic - nullptr)); -} - -void add_choose_qparams_per_token_asymmetric_node( - ComputeGraph& graph, - const ValueRef& input, - const ValueRef& scale_out, - const ValueRef& zero_point_out) { - std::string kernel_name("choose_qparams_per_token_asymmetric"); - add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); - add_dtype_suffix(kernel_name, graph.dtype_of(input)); - add_dtype_suffix(kernel_name, graph.dtype_of(scale_out)); - add_dtype_suffix(kernel_name, graph.dtype_of(zero_point_out)); - - // Calculate number of tokens (product of all dimensions except the last one) - int64_t num_tokens = 1; - const auto input_sizes = graph.sizes_of(input); - for (size_t i = 0; i < input_sizes.size() - 1; i++) { - num_tokens *= input_sizes[i]; - } - - int num_tokens_val = static_cast(num_tokens); - int quant_min_val = -128; // Fixed for asymmetric quantization - int quant_max_val = 127; // Fixed for asymmetric quantization - - vkapi::ParamsBindList param_ubos; - std::vector push_constants; - - if (graph.is_buffer_storage(input)) { - param_ubos = { - graph.sizes_ubo(input), - graph.strides_ubo(input), - graph.sizes_ubo(scale_out), - graph.strides_ubo(scale_out), - graph.sizes_ubo(zero_point_out), - graph.strides_ubo(zero_point_out)}; - } else { - param_ubos = { - graph.logical_limits_ubo(input), - graph.logical_limits_ubo(scale_out), - graph.logical_limits_ubo(zero_point_out)}; - } - - push_constants = { - PushConstantDataInfo(&num_tokens_val, sizeof(int)), - PushConstantDataInfo(&quant_min_val, sizeof(int)), - PushConstantDataInfo(&quant_max_val, sizeof(int)), - }; - - graph.execute_nodes().emplace_back(new DynamicDispatchNode( - graph, - VK_KERNEL_FROM_STR(kernel_name), - choose_qparams_per_token_pick_global_wg_size, - choose_qparams_per_token_pick_local_wg_size, - // Inputs and Outputs - {{scale_out, vkapi::kWrite}, - {zero_point_out, vkapi::kWrite}, - {input, vkapi::kRead}}, - // Shader param buffers - param_ubos, - // Push Constants - push_constants, - // Specialization Constants - {}, - // Resize Args - {}, - // Resizing Logic - nullptr)); -} - void add_choose_qparams_per_row_node( ComputeGraph& graph, const ValueRef& input, @@ -427,221 +129,6 @@ void add_choose_qparams_per_row_node( resize_choose_qparams_per_row)); } -void add_choose_qparams_block_wise_node( - ComputeGraph& graph, - ValueRef input, - ValueRef block_size, - int mapping_type, // 0 / 1 / 2 - ValueRef 
quant_min, - ValueRef quant_max, - ValueRef eps, - ValueRef scale_out, - ValueRef zp_out) { - const auto input_sizes = graph.sizes_of(input); - const auto block_size_list = graph.get_int_list(block_size); - - // For shader compatibility, we still need to convert to WHCN order - // but the output shape calculation is now handled correctly in resize - // function - utils::ivec4 block_size_vec = utils::make_whcn_ivec4(*block_size_list); - utils::ivec4 tensor_size_whcn = utils::make_whcn_ivec4(input_sizes); - - // Calculate numBlocks: ceil(tensorSize / blockSize) (both in WHCN order) - utils::ivec4 num_blocks_vec = { - (tensor_size_whcn[0] + block_size_vec[0] - 1) / block_size_vec[0], - (tensor_size_whcn[1] + block_size_vec[1] - 1) / block_size_vec[1], - (tensor_size_whcn[2] + block_size_vec[2] - 1) / block_size_vec[2], - (tensor_size_whcn[3] + block_size_vec[3] - 1) / block_size_vec[3]}; - - // Calculate blockStride: pre-computed linear strides for the block grid - utils::ivec4 block_stride_vec = { - 1, - num_blocks_vec[0], - num_blocks_vec[0] * num_blocks_vec[1], - num_blocks_vec[0] * num_blocks_vec[1] * num_blocks_vec[2]}; - - // Handle optional quant_min and quant_max parameters - int qmin, qmax; - if (graph.val_is_none(quant_min) || graph.val_is_none(quant_max)) { - // Use default values based on target_dtype (similar to - // _get_and_check_qmin_qmax) For now, assume int8 range as default - this - // should match the Python implementation - qmin = -128; - qmax = 127; - } else { - qmin = static_cast(graph.get_int(quant_min)); - qmax = static_cast(graph.get_int(quant_max)); - } - - float eps_val; - if (graph.val_is_none(eps)) { - // Use default eps value (similar to Python implementation) - eps_val = 1.192092896e-07f; // torch.finfo(torch.float32).eps - } else { - eps_val = static_cast(graph.get_double(eps)); - } - - // Create push constants vector - std::vector push_constants = { - PushConstantDataInfo(&block_size_vec, sizeof(block_size_vec)), - PushConstantDataInfo(&num_blocks_vec, sizeof(num_blocks_vec)), - PushConstantDataInfo(&block_stride_vec, sizeof(block_stride_vec)), - PushConstantDataInfo(&mapping_type, sizeof(int)), - PushConstantDataInfo(&qmin, sizeof(int)), - PushConstantDataInfo(&qmax, sizeof(int)), - PushConstantDataInfo(&eps_val, sizeof(float))}; - - std::string kernel_name("choose_qparams_block_wise"); - add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); - add_dtype_suffix(kernel_name, graph.dtype_of(input)); - add_dtype_suffix(kernel_name, graph.dtype_of(scale_out)); - add_dtype_suffix(kernel_name, graph.dtype_of(zp_out)); - - vkapi::ParamsBindList param_ubos; - - if (graph.is_buffer_storage(input)) { - param_ubos = { - graph.sizes_ubo(input), - graph.strides_ubo(input), - graph.sizes_ubo(scale_out), - graph.strides_ubo(scale_out), - graph.sizes_ubo(zp_out), - graph.strides_ubo(zp_out)}; - } else { - // For texture input, the shader uses buffer storage for outputs - // so we need buffer UBOs for the output tensors - param_ubos = { - graph.logical_limits_ubo(input), - graph.sizes_ubo(scale_out), - graph.strides_ubo(scale_out), - graph.sizes_ubo(zp_out), - graph.strides_ubo(zp_out)}; - } - - graph.execute_nodes().emplace_back(new DynamicDispatchNode( - graph, - VK_KERNEL_FROM_STR(kernel_name), - choose_qparams_block_wise_pick_global_wg_size, - choose_qparams_block_wise_pick_local_wg_size, - // Inputs and Outputs - {{scale_out, vkapi::kWrite}, - {zp_out, vkapi::kWrite}, - {input, vkapi::kRead}}, - // Shader param buffers - param_ubos, - // Push Constants - 
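      // These entries mirror the shader's BlockPC push-constant block: three
      // ivec4 fields (blockSize, numBlocks, blockStride) followed by
      // mapping_type, quant_min, quant_max and eps. Assuming std430-style
      // packing, the expected byte offsets are 0, 16, 32, 48, 52, 56 and 60
      // (offsets are illustrative, derived only from the declaration order).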
push_constants, - // Specialization Constants - {}, - // Resize Args - {block_size}, - // Resizing Logic - nullptr)); -} - -void choose_qparams_tensor_impl( - ComputeGraph& graph, - const std::vector& args) { - int arg_idx = 0; - const ValueRef input = args[arg_idx++]; - const ValueRef quant_min = args[arg_idx++]; - const ValueRef quant_max = args[arg_idx++]; - const ValueRef eps = args[arg_idx++]; - const ValueRef dtype = args[arg_idx++]; - const ValueRef out_tuple_ref = args[arg_idx++]; - - ValueRef scale_out = kDummyValueRef; - ValueRef zero_point_out = kDummyValueRef; - - { - const ValueListPtr out_tuple = graph.get_value_list(out_tuple_ref); - scale_out = out_tuple->at(0); - zero_point_out = out_tuple->at(1); - } - - // Void the unused dtype parameter to match ATen signature - (void)dtype; - - // Check tensor types - VK_CHECK_COND(graph.val_is_tensor(input)); - VK_CHECK_COND(graph.val_is_tensor(scale_out)); - VK_CHECK_COND(graph.val_is_tensor(zero_point_out)); - - // Verify input is a floating point type - VK_CHECK_COND(graph.dtype_of(input) == vkapi::kFloat); - - // Get scale and zero point output dtypes - vkapi::ScalarType scale_out_dtype = graph.dtype_of(scale_out); - vkapi::ScalarType zero_point_out_dtype = graph.dtype_of(zero_point_out); - - // Verify supported output types for scale (fp32 only for now) - VK_CHECK_COND(scale_out_dtype == vkapi::kFloat); - - // Verify supported output types for zero point (int32, int8, fp32) - VK_CHECK_COND( - zero_point_out_dtype == vkapi::kInt || - zero_point_out_dtype == vkapi::kChar || - zero_point_out_dtype == vkapi::kFloat); - - // Check that texture storage is width packed - if (!graph.is_buffer_storage(input)) { - VK_CHECK_COND(graph.packed_dim_of(input) == WHCN::kWidthDim); - } - - add_choose_qparams_tensor_node( - graph, input, quant_min, quant_max, eps, scale_out, zero_point_out); -} - -void choose_qparams_per_token_asymmetric_impl( - ComputeGraph& graph, - const std::vector& args) { - int arg_idx = 0; - const ValueRef input = args[arg_idx++]; - const ValueRef dtype = args[arg_idx++]; - const ValueRef out_tuple_ref = args[arg_idx++]; - - ValueRef scale_out = kDummyValueRef; - ValueRef zero_point_out = kDummyValueRef; - - { - const ValueListPtr out_tuple = graph.get_value_list(out_tuple_ref); - scale_out = out_tuple->at(0); - zero_point_out = out_tuple->at(1); - } - - // Void the unused parameter to match ATen signature - (void)dtype; - - // Check tensor types - VK_CHECK_COND(graph.val_is_tensor(input)); - VK_CHECK_COND(graph.val_is_tensor(scale_out)); - VK_CHECK_COND(graph.val_is_tensor(zero_point_out)); - - // Verify input is a floating point type - VK_CHECK_COND(graph.dtype_of(input) == vkapi::kFloat); - - // Get scale and zero point output dtypes - vkapi::ScalarType scale_out_dtype = graph.dtype_of(scale_out); - vkapi::ScalarType zero_point_out_dtype = graph.dtype_of(zero_point_out); - - // Verify supported output types for scale (fp32 only for now) - VK_CHECK_COND(scale_out_dtype == vkapi::kFloat); - - // Verify supported output types for zero point (int32, int8, fp32) - VK_CHECK_COND( - zero_point_out_dtype == vkapi::kInt || - zero_point_out_dtype == vkapi::kChar || - zero_point_out_dtype == vkapi::kFloat); - - // Check that texture storage is width packed - if (!graph.is_buffer_storage(input)) { - VK_CHECK_COND(graph.packed_dim_of(input) == WHCN::kWidthDim); - } - - add_choose_qparams_per_token_asymmetric_node( - graph, input, scale_out, zero_point_out); -} - bool can_use_choose_qparams_per_row( ComputeGraph& graph, const ValueRef 
input, @@ -674,11 +161,13 @@ void choose_qparams_affine_impl( int arg_idx = 0; const ValueRef input = args[arg_idx++]; const ValueRef mapping_type = args[arg_idx++]; + (void)mapping_type; const ValueRef block_size = args[arg_idx++]; const ValueRef target_dtype = args[arg_idx++]; const ValueRef quant_min = args[arg_idx++]; const ValueRef quant_max = args[arg_idx++]; const ValueRef eps = args[arg_idx++]; + (void)eps; const ValueRef scale_dtype = args[arg_idx++]; const ValueRef zero_point_dtype = args[arg_idx++]; const ValueRef out_tuple_ref = args[arg_idx++]; @@ -704,59 +193,7 @@ void choose_qparams_affine_impl( graph, input, quant_min, quant_max, scale_out, zero_point_out); } - // Check tensor types - VK_CHECK_COND(graph.val_is_tensor(input)); - VK_CHECK_COND(graph.val_is_tensor(scale_out)); - VK_CHECK_COND(graph.val_is_tensor(zero_point_out)); - - // Verify input is a floating point type - VK_CHECK_COND(graph.dtype_of(input) == vkapi::kFloat); - - // Get scale and zero point dtypes from arguments - vkapi::ScalarType scale_out_dtype = graph.dtype_of(scale_out); - vkapi::ScalarType zero_point_out_dtype = graph.dtype_of(zero_point_out); - - // Verify supported output types for scale (fp32 only for now) - VK_CHECK_COND(scale_out_dtype == vkapi::kFloat); - - // Verify supported output types for zero point (int32, int8, fp32) - VK_CHECK_COND( - zero_point_out_dtype == vkapi::kInt || - zero_point_out_dtype == vkapi::kChar || - zero_point_out_dtype == vkapi::kFloat); - - // Check that texture storage is width packed - if (!graph.is_buffer_storage(input)) { - VK_CHECK_COND(graph.packed_dim_of(input) == WHCN::kWidthDim); - } - - const auto input_sizes = graph.sizes_of(input); - const auto block_size_list = graph.get_int_list(block_size); - VK_CHECK_COND(block_size_list->size() == input_sizes.size()); - - std::string mapping_type_str = graph.get_string(mapping_type); - int mapping_type_val = 0; // Default to ASYMMETRIC - - if (mapping_type_str == "ASYMMETRIC" || mapping_type_str.empty()) { - mapping_type_val = 0; // ASYMMETRIC - } else if (mapping_type_str == "SYMMETRIC") { - mapping_type_val = 1; - } else if (mapping_type_str == "SYMMETRIC_NO_CLIPPING_ERR") { - mapping_type_val = 2; - } else { - VK_THROW("Unsupported mapping_type: ", mapping_type_str); - } - - add_choose_qparams_block_wise_node( - graph, - input, - block_size, - mapping_type_val, - quant_min, - quant_max, - eps, - scale_out, - zero_point_out); + VK_THROW("Unsupported input case for choose_qparams_affine"); } void choose_qparams_per_row( @@ -769,27 +206,11 @@ void choose_qparams_per_row( const ValueRef input_scales = args[arg_idx++]; const ValueRef input_zps = args[arg_idx++]; - // ValueRef scale_out = kDummyValueRef; - // ValueRef zero_point_out = kDummyValueRef; - // - // { - // const ValueListPtr out_tuple = graph.get_value_list(out_tuple_ref); - // scale_out = out_tuple->at(0); - // zero_point_out = out_tuple->at(1); - // } - // - add_choose_qparams_per_row_node( graph, input, quant_min, quant_max, input_scales, input_zps); } REGISTER_OPERATORS { - VK_REGISTER_OP( - quantized_decomposed.choose_qparams.tensor, choose_qparams_tensor_impl); - VK_REGISTER_OP( - quantized_decomposed.choose_qparams_per_token_asymmetric.default, - choose_qparams_per_token_asymmetric_impl); - // Register the per-channel quantization operator VK_REGISTER_OP(etvk.choose_qparams_per_row.default, choose_qparams_per_row); diff --git a/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp b/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp deleted file 
mode 100644 index a217734653d..00000000000 --- a/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp +++ /dev/null @@ -1,843 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#include -#include -#include - -#include -#include -#include -#include - -namespace vkcompute { - -void resize_dequantize_node( - ComputeGraph* graph, - const std::vector& args, - const std::vector& extra_args) { - (void)extra_args; - - const ValueRef out = args.at(0).refs.at(0); - const ValueRef in = args.at(1).refs.at(0); - - const std::vector in_sizes = graph->sizes_of(in); - graph->virtual_resize(out, in_sizes); -} - -utils::uvec3 dequantize_per_channel_local_wg_size( - ComputeGraph* graph, - const vkapi::ShaderInfo& shader, - const utils::uvec3& global_workgroup_size, - const std::vector& args, - const std::vector& resize_args) { - (void)args; - (void)resize_args; - - const ValueRef input = args.at(1).refs.at(0); - - utils::uvec3 local_wg_size = - graph->create_local_wg_size(global_workgroup_size); - - // WORKAROUND: The CommandBuffer::dispatch function divides - // global_workgroup_size by local_workgroup_size to get the number of - // workgroups to dispatch. We need to ensure that we dispatch the correct - // number of workgroups in the Z dimension to cover all batch-channel - // combinations. - // - // If local_wg_size[2] > 1, then div_up(global_workgroup_size[2], - // local_wg_size[2]) might reduce the number of workgroups dispatched. To - // ensure we dispatch global_workgroup_size[2] workgroups in the Z dimension, - // we set local_wg_size[2] = 1. - const auto input_sizes = graph->sizes_of(input); - if (input_sizes.size() == 4 && !graph->is_buffer_storage(input) && - global_workgroup_size[2] > 1) { - local_wg_size[2] = 1; - } - - return local_wg_size; -} - -utils::uvec3 dequantize_block_wise_local_wg_size( - ComputeGraph* graph, - const vkapi::ShaderInfo& shader, - const utils::uvec3& global_workgroup_size, - const std::vector& args, - const std::vector& resize_args) { - (void)shader; - (void)resize_args; - const ValueRef input = args.at(1).refs.at(0); - - utils::uvec3 local_wg_size = - graph->create_local_wg_size(global_workgroup_size); - - // WORKAROUND: The CommandBuffer::dispatch function divides - // global_workgroup_size by local_workgroup_size to get the number of - // workgroups to dispatch. We need to ensure that we dispatch the correct - // number of workgroups in the Z dimension to cover all batch-channel - // combinations. - // - // If local_wg_size[2] > 1, then div_up(global_workgroup_size[2], - // local_wg_size[2]) might reduce the number of workgroups dispatched. To - // ensure we dispatch global_workgroup_size[2] workgroups in the Z dimension, - // we set local_wg_size[2] = 1. 
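  // Illustration (sizes assumed for the example): if the folded
  // batch*channel extent makes global_workgroup_size[2] == 6, forcing
  // local_wg_size[2] = 1 dispatches exactly 6 workgroups along Z, one per
  // batch-channel slice, independent of how CommandBuffer::dispatch rounds
  // the division.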
- const auto input_sizes = graph->sizes_of(input); - if (input_sizes.size() == 4 && !graph->is_buffer_storage(input) && - global_workgroup_size[2] > 1) { - local_wg_size[2] = 1; - } - - return local_wg_size; -} - -void add_dequantize_per_tensor_node( - ComputeGraph& graph, - const ValueRef& input, - const ValueRef& scale, - const ValueRef& zero_point, - const ValueRef& quant_min, - const ValueRef& quant_max, - const ValueRef& output) { - std::string kernel_name("dequantize_per_tensor"); - add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); - add_dtype_suffix(kernel_name, graph.dtype_of(input)); - add_dtype_suffix(kernel_name, graph.dtype_of(output)); - add_dtype_suffix(kernel_name, graph.dtype_of(scale)); - add_dtype_suffix(kernel_name, graph.dtype_of(zero_point)); - - // Handle optional quant_min and quant_max parameters independently - auto bounds = get_dtype_bounds(graph.dtype_of(input)); - - int quant_min_val, quant_max_val; - - // Handle quant_min - if (graph.val_is_none(quant_min)) { - quant_min_val = bounds.first; - } else { - VK_CHECK_COND( - graph.val_is_int(quant_min), - "quant_min must be an integer, got type: ", - graph.get_val_type(quant_min)); - quant_min_val = static_cast(graph.get_int(quant_min)); - } - - // Handle quant_max - if (graph.val_is_none(quant_max)) { - quant_max_val = bounds.second; - } else { - VK_CHECK_COND( - graph.val_is_int(quant_max), - "quant_max must be an integer, got type: ", - graph.get_val_type(quant_max)); - quant_max_val = static_cast(graph.get_int(quant_max)); - } - - vkapi::ParamsBindList param_ubos; - std::vector push_constants; - - if (graph.is_buffer_storage(input)) { - param_ubos = { - graph.numel_ubo(input), - graph.sizes_ubo(input), - graph.strides_ubo(input), - graph.sizes_ubo(output), - graph.strides_ubo(output)}; - } else { - param_ubos = { - graph.logical_limits_ubo(input), graph.logical_limits_ubo(output)}; - } - - push_constants = { - PushConstantDataInfo(&quant_min_val, sizeof(int)), - PushConstantDataInfo(&quant_max_val, sizeof(int)), - }; - - vkapi::SpecVarList spec_vars = { - graph.hashed_layout_of(output), - graph.hashed_layout_of(input), - }; - - graph.execute_nodes().emplace_back(new DynamicDispatchNode( - graph, - VK_KERNEL_FROM_STR(kernel_name), - default_pick_global_wg_size, - default_pick_local_wg_size, - // Inputs and Outputs - {{output, vkapi::kWrite}, - {input, vkapi::kRead}, - {{scale, zero_point}, vkapi::kRead}}, - // Shader param buffers - param_ubos, - // Push Constants - push_constants, - // Specialization Constants - spec_vars, - // Resize Args - {}, - // Resizing Logic - resize_dequantize_node)); -} - -void add_dequantize_per_token_node( - ComputeGraph& graph, - const ValueRef& input, - const ValueRef& scale, - const ValueRef& zero_point, - const ValueRef& quant_min, - const ValueRef& quant_max, - const ValueRef& output) { - std::string kernel_name("dequantize_per_token"); - add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); - add_dtype_suffix(kernel_name, graph.dtype_of(input)); - add_dtype_suffix(kernel_name, graph.dtype_of(output)); - add_dtype_suffix(kernel_name, graph.dtype_of(scale)); - add_dtype_suffix(kernel_name, graph.dtype_of(zero_point)); - - // Handle optional quant_min and quant_max parameters independently - auto bounds = get_dtype_bounds(graph.dtype_of(input)); - - int quant_min_val, quant_max_val; - - // Handle quant_min - if (graph.val_is_none(quant_min)) { - quant_min_val = bounds.first; - } else { - VK_CHECK_COND( - graph.val_is_int(quant_min), - "quant_min must 
be an integer, got type: ", - graph.get_val_type(quant_min)); - quant_min_val = static_cast(graph.get_int(quant_min)); - } - - // Handle quant_max - if (graph.val_is_none(quant_max)) { - quant_max_val = bounds.second; - } else { - VK_CHECK_COND( - graph.val_is_int(quant_max), - "quant_max must be an integer, got type: ", - graph.get_val_type(quant_max)); - quant_max_val = static_cast(graph.get_int(quant_max)); - } - - int num_tokens = static_cast(graph.sizes_of(scale)[0]); - - vkapi::ParamsBindList param_ubos; - std::vector push_constants; - - if (graph.is_buffer_storage(input)) { - param_ubos = { - graph.numel_ubo(input), - graph.sizes_ubo(input), - graph.strides_ubo(input), - graph.sizes_ubo(output), - graph.strides_ubo(output)}; - } else { - param_ubos = { - graph.logical_limits_ubo(input), graph.logical_limits_ubo(output)}; - } - - push_constants = { - PushConstantDataInfo(&num_tokens, sizeof(int)), - PushConstantDataInfo(&quant_min_val, sizeof(int)), - PushConstantDataInfo(&quant_max_val, sizeof(int)), - }; - - vkapi::SpecVarList spec_vars = { - graph.hashed_layout_of(output), - graph.hashed_layout_of(input), - }; - - graph.execute_nodes().emplace_back(new DynamicDispatchNode( - graph, - VK_KERNEL_FROM_STR(kernel_name), - default_pick_global_wg_size, - default_pick_local_wg_size, - // Inputs and Outputs - {{output, vkapi::kWrite}, - {input, vkapi::kRead}, - {{scale, zero_point}, vkapi::kRead}}, - // Shader param buffers - param_ubos, - // Push Constants - push_constants, - // Specialization Constants - spec_vars, - // Resize Args - {}, - // Resizing Logic - resize_dequantize_node)); -} - -void add_dequantize_per_channel_node( - ComputeGraph& graph, - const ValueRef& input, - const ValueRef& scale, - const ValueRef& zero_point, - const ValueRef& axis, - const ValueRef& quant_min, - const ValueRef& quant_max, - const ValueRef& output) { - std::string kernel_name("dequantize_per_channel"); - add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); - add_dtype_suffix(kernel_name, graph.dtype_of(input)); - add_dtype_suffix(kernel_name, graph.dtype_of(output)); - add_dtype_suffix(kernel_name, graph.dtype_of(scale)); - add_dtype_suffix(kernel_name, graph.dtype_of(zero_point)); - - int axis_val = static_cast(graph.get_int(axis)); - - // Handle optional quant_min and quant_max parameters independently - auto bounds = get_dtype_bounds(graph.dtype_of(input)); - - int quant_min_val, quant_max_val; - - // Handle quant_min - if (graph.val_is_none(quant_min)) { - quant_min_val = bounds.first; - } else { - VK_CHECK_COND( - graph.val_is_int(quant_min), - "quant_min must be an integer, got type: ", - graph.get_val_type(quant_min)); - quant_min_val = static_cast(graph.get_int(quant_min)); - } - - // Handle quant_max - if (graph.val_is_none(quant_max)) { - quant_max_val = bounds.second; - } else { - VK_CHECK_COND( - graph.val_is_int(quant_max), - "quant_max must be an integer, got type: ", - graph.get_val_type(quant_max)); - quant_max_val = static_cast(graph.get_int(quant_max)); - } - - // Normalize axis and convert from NCHW to WHCN using utility functions - const auto input_sizes = graph.sizes_of(input); - const int64_t ndim = graph.dim_of(input); - - // Normalize axis to handle negative indices - axis_val = normalize(axis_val, ndim); - - // Convert from NCHW axis to WHCN axis for shader (vulkan representation) - int axis_whcn = nchw_dim_to_whcn_dim(axis_val, ndim); - - int num_channels; - if (axis_val == 0 && ndim == 4 && !graph.is_buffer_storage(input)) { - // For batch dimension 
dequantization in 4D tensors, pass the actual number - // of channels so the shader can correctly unfold the batch-channel folding - num_channels = static_cast(input_sizes[1]); // Channel dimension - } else { - num_channels = static_cast(input_sizes[axis_val]); - } - - vkapi::ParamsBindList param_ubos; - std::vector push_constants; - - if (graph.is_buffer_storage(input)) { - param_ubos = { - graph.numel_ubo(input), - graph.sizes_ubo(input), - graph.strides_ubo(input), - graph.sizes_ubo(output), - graph.strides_ubo(output)}; - } else { - param_ubos = { - graph.logical_limits_ubo(input), graph.logical_limits_ubo(output)}; - } - - push_constants = { - PushConstantDataInfo(&axis_whcn, sizeof(int)), - PushConstantDataInfo(&num_channels, sizeof(int)), - PushConstantDataInfo(&quant_min_val, sizeof(int)), - PushConstantDataInfo(&quant_max_val, sizeof(int)), - }; - - vkapi::SpecVarList spec_vars = { - graph.hashed_layout_of(output), - graph.hashed_layout_of(input), - }; - - graph.execute_nodes().emplace_back(new DynamicDispatchNode( - graph, - VK_KERNEL_FROM_STR(kernel_name), - default_pick_global_wg_size, - dequantize_per_channel_local_wg_size, - // Inputs and Outputs - {{output, vkapi::kWrite}, - {input, vkapi::kRead}, - {{scale, zero_point}, vkapi::kRead}}, - // Shader param buffers - param_ubos, - // Push Constants - push_constants, - // Specialization Constants - spec_vars, - // Resize Args - {}, - // Resizing Logic - resize_dequantize_node)); -} - -void add_dequantize_block_wise_node( - ComputeGraph& graph, - const ValueRef& input, - const ValueRef& block_size, - const ValueRef& scale, - const ValueRef& zero_point, - const ValueRef& quant_min, - const ValueRef& quant_max, - const ValueRef& output) { - std::string kernel_name("dequantize_block_wise"); - add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); - add_dtype_suffix(kernel_name, graph.dtype_of(input)); - add_dtype_suffix(kernel_name, graph.dtype_of(output)); - add_dtype_suffix(kernel_name, graph.dtype_of(scale)); - add_dtype_suffix(kernel_name, graph.dtype_of(zero_point)); - - // Handle optional quant_min and quant_max parameters independently - auto bounds = get_dtype_bounds(graph.dtype_of(input)); - - int quant_min_val, quant_max_val; - - // Handle quant_min - if (graph.val_is_none(quant_min)) { - quant_min_val = bounds.first; - } else { - VK_CHECK_COND( - graph.val_is_int(quant_min), - "quant_min must be an integer, got type: ", - graph.get_val_type(quant_min)); - quant_min_val = static_cast(graph.get_int(quant_min)); - } - - // Handle quant_max - if (graph.val_is_none(quant_max)) { - quant_max_val = bounds.second; - } else { - VK_CHECK_COND( - graph.val_is_int(quant_max), - "quant_max must be an integer, got type: ", - graph.get_val_type(quant_max)); - quant_max_val = static_cast(graph.get_int(quant_max)); - } - - const auto input_sizes = graph.sizes_of(input); - const auto block_size_list = graph.get_int_list(block_size); - - // Convert dimensions to WHCN order for shader - utils::ivec4 block_size_vec = utils::make_whcn_ivec4(*block_size_list); - utils::ivec4 tensor_size_whcn = utils::make_whcn_ivec4(input_sizes); - - // Calculate numBlocks: tensorSize / blockSize (both in WHCN order) - utils::ivec4 num_blocks_vec = { - tensor_size_whcn[0] / block_size_vec[0], - tensor_size_whcn[1] / block_size_vec[1], - tensor_size_whcn[2] / block_size_vec[2], - tensor_size_whcn[3] / block_size_vec[3]}; - - // Calculate blockStride: pre-computed linear strides for the block grid - utils::ivec4 block_stride_vec = { - 1, - 
num_blocks_vec[0], - num_blocks_vec[0] * num_blocks_vec[1], - num_blocks_vec[0] * num_blocks_vec[1] * num_blocks_vec[2]}; - - vkapi::ParamsBindList param_ubos; - std::vector push_constants; - - if (graph.is_buffer_storage(input)) { - param_ubos = { - graph.numel_ubo(input), - graph.sizes_ubo(input), - graph.strides_ubo(input), - graph.sizes_ubo(output), - graph.strides_ubo(output)}; - } else { - param_ubos = { - graph.logical_limits_ubo(input), graph.logical_limits_ubo(output)}; - } - - push_constants = { - PushConstantDataInfo(&block_size_vec, sizeof(block_size_vec)), - PushConstantDataInfo(&num_blocks_vec, sizeof(num_blocks_vec)), - PushConstantDataInfo(&block_stride_vec, sizeof(block_stride_vec)), - PushConstantDataInfo(&quant_min_val, sizeof(int)), - PushConstantDataInfo(&quant_max_val, sizeof(int)), - }; - - vkapi::SpecVarList spec_vars = { - graph.hashed_layout_of(output), - graph.hashed_layout_of(input), - }; - - graph.execute_nodes().emplace_back(new DynamicDispatchNode( - graph, - VK_KERNEL_FROM_STR(kernel_name), - default_pick_global_wg_size, - dequantize_block_wise_local_wg_size, - // Inputs and Outputs - {{output, vkapi::kWrite}, - {input, vkapi::kRead}, - {{scale, zero_point}, vkapi::kRead}}, - // Shader param buffers - param_ubos, - // Push Constants - push_constants, - // Specialization Constants - spec_vars, - // Resize Args - {}, - // Resizing Logic - resize_dequantize_node)); -} - -void dequantize_per_tensor_impl( - ComputeGraph& graph, - const std::vector& args) { - int arg_idx = 0; - const ValueRef input = args[arg_idx++]; - const ValueRef scale = args[arg_idx++]; - const ValueRef zero_point = args[arg_idx++]; - const ValueRef quant_min = args[arg_idx++]; - const ValueRef quant_max = args[arg_idx++]; - const ValueRef dtype = args[arg_idx++]; - const ValueRef output_dtype = args[arg_idx++]; - const ValueRef output = args[arg_idx++]; - - // Suppress unused variable warnings - dtype and output_dtype are inferred - (void)dtype; - (void)output_dtype; - - // Check tensor types - VK_CHECK_COND(graph.val_is_tensor(input)); - VK_CHECK_COND(graph.val_is_tensor(scale)); - VK_CHECK_COND(graph.val_is_tensor(zero_point)); - VK_CHECK_COND(graph.val_is_tensor(output)); - - // Verify input is an integer type - VK_CHECK_COND( - graph.dtype_of(input) == vkapi::kByte || - graph.dtype_of(input) == vkapi::kChar || - graph.dtype_of(input) == vkapi::kInt); - - // Get scale and zero point dtypes - vkapi::ScalarType scale_dtype = graph.dtype_of(scale); - vkapi::ScalarType zero_point_dtype = graph.dtype_of(zero_point); - - // Verify supported types for scale (fp32 only for now) - VK_CHECK_COND(scale_dtype == vkapi::kFloat); - - // Verify supported types for zero point (int32, int8, fp32) - VK_CHECK_COND( - zero_point_dtype == vkapi::kInt || zero_point_dtype == vkapi::kChar || - zero_point_dtype == vkapi::kFloat); - - // Check that scale and zero_point have buffer storage and width packing - VK_CHECK_COND(graph.is_buffer_storage(scale)); - VK_CHECK_COND(graph.packed_dim_of(scale) == WHCN::kWidthDim); - VK_CHECK_COND(graph.is_buffer_storage(zero_point)); - VK_CHECK_COND(graph.packed_dim_of(zero_point) == WHCN::kWidthDim); - - // Check that tensors with texture storage have standard axis map - if (!graph.is_buffer_storage(input)) { - VK_CHECK_COND(graph.has_standard_axis_map(input)); - } - if (!graph.is_buffer_storage(output)) { - VK_CHECK_COND(graph.has_standard_axis_map(output)); - } - - add_dequantize_per_tensor_node( - graph, input, scale, zero_point, quant_min, quant_max, output); -} - -void 
dequantize_per_token_impl( - ComputeGraph& graph, - const std::vector& args) { - int arg_idx = 0; - const ValueRef input = args[arg_idx++]; - const ValueRef scale = args[arg_idx++]; - const ValueRef zero_point = args[arg_idx++]; - const ValueRef quant_min = args[arg_idx++]; - const ValueRef quant_max = args[arg_idx++]; - const ValueRef dtype = args[arg_idx++]; - const ValueRef output_dtype = args[arg_idx++]; - const ValueRef output = args[arg_idx++]; - - // Suppress unused variable warnings - dtype and output_dtype are inferred - (void)dtype; - (void)output_dtype; - - // Check tensor types - VK_CHECK_COND(graph.val_is_tensor(input)); - VK_CHECK_COND(graph.val_is_tensor(scale)); - VK_CHECK_COND(graph.val_is_tensor(zero_point)); - VK_CHECK_COND(graph.val_is_tensor(output)); - - // Verify input is an integer type - VK_CHECK_COND( - graph.dtype_of(input) == vkapi::kByte || - graph.dtype_of(input) == vkapi::kChar || - graph.dtype_of(input) == vkapi::kInt); - - // Get scale and zero point dtypes - vkapi::ScalarType scale_dtype = graph.dtype_of(scale); - vkapi::ScalarType zero_point_dtype = graph.dtype_of(zero_point); - - // Verify supported types for scale (fp32 only for now) - VK_CHECK_COND(scale_dtype == vkapi::kFloat); - - // Verify supported types for zero point (int32, int8, fp32) - VK_CHECK_COND( - zero_point_dtype == vkapi::kInt || zero_point_dtype == vkapi::kChar || - zero_point_dtype == vkapi::kFloat); - - // Check that scale and zero_point have buffer storage and width packing - VK_CHECK_COND(graph.is_buffer_storage(scale)); - VK_CHECK_COND(graph.packed_dim_of(scale) == WHCN::kWidthDim); - VK_CHECK_COND(graph.is_buffer_storage(zero_point)); - VK_CHECK_COND(graph.packed_dim_of(zero_point) == WHCN::kWidthDim); - - // Check that tensors with texture storage have standard axis map - if (!graph.is_buffer_storage(input)) { - VK_CHECK_COND(graph.has_standard_axis_map(input)); - } - if (!graph.is_buffer_storage(output)) { - VK_CHECK_COND(graph.has_standard_axis_map(output)); - } - - // Calculate number of tokens (product of all dimensions except the last one) - int64_t num_tokens = 1; - const auto input_sizes = graph.sizes_of(input); - for (size_t i = 0; i < input_sizes.size() - 1; i++) { - num_tokens *= input_sizes[i]; - } - - const auto scale_sizes = graph.sizes_of(scale); - const auto zero_point_sizes = graph.sizes_of(zero_point); - - // Calculate total number of elements in scale and zero_point tensors - int64_t scale_numel = 1; - for (size_t i = 0; i < scale_sizes.size(); i++) { - scale_numel *= scale_sizes[i]; - } - - int64_t zero_point_numel = 1; - for (size_t i = 0; i < zero_point_sizes.size(); i++) { - zero_point_numel *= zero_point_sizes[i]; - } - - // Check that the total number of elements matches num_tokens - // This allows for both 1D tensors (size [num_tokens]) and reshaped tensors - // (size [num_tokens, 1]) - VK_CHECK_COND(scale_numel == num_tokens); - VK_CHECK_COND(zero_point_numel == num_tokens); - - add_dequantize_per_token_node( - graph, input, scale, zero_point, quant_min, quant_max, output); -} - -void dequantize_per_channel_impl( - ComputeGraph& graph, - const std::vector& args) { - int arg_idx = 0; - const ValueRef input = args[arg_idx++]; - const ValueRef scale = args[arg_idx++]; - const ValueRef zero_point = args[arg_idx++]; - const ValueRef axis = args[arg_idx++]; - const ValueRef quant_min = args[arg_idx++]; - const ValueRef quant_max = args[arg_idx++]; - const ValueRef dtype = args[arg_idx++]; - const ValueRef output_dtype = args[arg_idx++]; - const ValueRef 
output = args[arg_idx++]; - - // Suppress unused variable warnings - dtype and output_dtype are inferred - (void)dtype; - (void)output_dtype; - - // Check tensor types - VK_CHECK_COND(graph.val_is_tensor(input)); - VK_CHECK_COND(graph.val_is_tensor(scale)); - VK_CHECK_COND(graph.val_is_tensor(zero_point)); - VK_CHECK_COND(graph.val_is_tensor(output)); - - // Verify input is an integer type - VK_CHECK_COND( - graph.dtype_of(input) == vkapi::kByte || - graph.dtype_of(input) == vkapi::kChar || - graph.dtype_of(input) == vkapi::kInt); - - // Get scale and zero point dtypes - vkapi::ScalarType scale_dtype = graph.dtype_of(scale); - vkapi::ScalarType zero_point_dtype = graph.dtype_of(zero_point); - - // Verify supported types for scale (fp32 only for now) - VK_CHECK_COND(scale_dtype == vkapi::kFloat); - - // Verify supported types for zero point (int32, int8, fp32) - VK_CHECK_COND( - zero_point_dtype == vkapi::kInt || zero_point_dtype == vkapi::kChar || - zero_point_dtype == vkapi::kFloat); - - // Check that scale and zero_point have buffer storage and width packing - VK_CHECK_COND(graph.is_buffer_storage(scale)); - VK_CHECK_COND(graph.packed_dim_of(scale) == WHCN::kWidthDim); - VK_CHECK_COND(graph.is_buffer_storage(zero_point)); - VK_CHECK_COND(graph.packed_dim_of(zero_point) == WHCN::kWidthDim); - - // Check that tensors with texture storage have standard axis map - if (!graph.is_buffer_storage(input)) { - VK_CHECK_COND(graph.has_standard_axis_map(input)); - } - if (!graph.is_buffer_storage(output)) { - VK_CHECK_COND(graph.has_standard_axis_map(output)); - } - - // Normalize axis - int axis_val = static_cast(graph.get_int(axis)); - const auto input_sizes = graph.sizes_of(input); - int ndim = graph.dim_of(input); - if (axis_val < 0) { - axis_val += ndim; - } - - // Verify axis is valid - VK_CHECK_COND(axis_val >= 0 && axis_val < ndim); - - // Get number of channels along the specified axis - int64_t num_channels = input_sizes[axis_val]; - - const auto scale_sizes = graph.sizes_of(scale); - const auto zero_point_sizes = graph.sizes_of(zero_point); - - // Calculate total number of elements in scale and zero_point tensors - int64_t scale_numel = 1; - for (size_t i = 0; i < scale_sizes.size(); i++) { - scale_numel *= scale_sizes[i]; - } - - int64_t zero_point_numel = 1; - for (size_t i = 0; i < zero_point_sizes.size(); i++) { - zero_point_numel *= zero_point_sizes[i]; - } - - // Check that the total number of elements matches num_channels - VK_CHECK_COND(scale_numel == num_channels); - VK_CHECK_COND(zero_point_numel == num_channels); - - add_dequantize_per_channel_node( - graph, input, scale, zero_point, axis, quant_min, quant_max, output); -} - -void dequantize_affine_impl( - ComputeGraph& graph, - const std::vector& args) { - int arg_idx = 0; - const ValueRef input = args[arg_idx++]; - const ValueRef block_size = args[arg_idx++]; - const ValueRef scale = args[arg_idx++]; - const ValueRef zero_point = args[arg_idx++]; - const ValueRef input_dtype = args[arg_idx++]; - const ValueRef quant_min = args[arg_idx++]; - const ValueRef quant_max = args[arg_idx++]; - const ValueRef output_dtype = args[arg_idx++]; - const ValueRef output = args[arg_idx++]; - - // Suppress unused variable warnings - (void)input_dtype; - (void)output_dtype; - - // Check tensor types - VK_CHECK_COND(graph.val_is_tensor(input)); - VK_CHECK_COND(graph.val_is_tensor(scale)); - VK_CHECK_COND(graph.val_is_tensor(zero_point)); - VK_CHECK_COND(graph.val_is_tensor(output)); - - // Verify input is an integer type - VK_CHECK_COND( - 
graph.dtype_of(input) == vkapi::kByte || - graph.dtype_of(input) == vkapi::kChar || - graph.dtype_of(input) == vkapi::kInt); - - // Get scale and zero point dtypes - vkapi::ScalarType scale_dtype = graph.dtype_of(scale); - vkapi::ScalarType zero_point_dtype = graph.dtype_of(zero_point); - - // Verify supported types for scale (fp32 only for now) - VK_CHECK_COND(scale_dtype == vkapi::kFloat); - - // Verify supported types for zero point (int32, int8, fp32) - VK_CHECK_COND( - zero_point_dtype == vkapi::kInt || zero_point_dtype == vkapi::kChar || - zero_point_dtype == vkapi::kFloat); - - // Check that scale and zero_point have buffer storage and width packing - VK_CHECK_COND(graph.is_buffer_storage(scale)); - VK_CHECK_COND(graph.packed_dim_of(scale) == WHCN::kWidthDim); - VK_CHECK_COND(graph.is_buffer_storage(zero_point)); - VK_CHECK_COND(graph.packed_dim_of(zero_point) == WHCN::kWidthDim); - - // Check that tensors with texture storage have standard axis map - if (!graph.is_buffer_storage(input)) { - VK_CHECK_COND(graph.has_standard_axis_map(input)); - } - if (!graph.is_buffer_storage(output)) { - VK_CHECK_COND(graph.has_standard_axis_map(output)); - } - - // Verify block_size is valid (each dimension must divide evenly into input - // size) - const auto input_sizes = graph.sizes_of(input); - const auto block_size_list = graph.get_int_list(block_size); - VK_CHECK_COND(block_size_list->size() == input_sizes.size()); - - for (size_t i = 0; i < input_sizes.size(); i++) { - if ((*block_size_list)[i] > 1) { - VK_CHECK_COND( - input_sizes[i] % (*block_size_list)[i] == 0, - "Input size at dimension ", - i, - " (", - input_sizes[i], - ") must be divisible by block_size at dimension ", - i, - " (", - (*block_size_list)[i], - ")"); - } - } - - add_dequantize_block_wise_node( - graph, - input, - block_size, - scale, - zero_point, - quant_min, - quant_max, - output); -} - -REGISTER_OPERATORS { - VK_REGISTER_OP( - quantized_decomposed.dequantize_per_tensor.tensor, - dequantize_per_tensor_impl); - VK_REGISTER_OP( - quantized_decomposed.dequantize_per_token.default, - dequantize_per_token_impl); - VK_REGISTER_OP( - quantized_decomposed.dequantize_per_channel.default, - dequantize_per_channel_impl); - - // TorchAO affine dequantization operators - VK_REGISTER_OP(torchao.dequantize_affine.default, dequantize_affine_impl); -} - -} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp b/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp deleted file mode 100644 index 88f77261f4f..00000000000 --- a/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp +++ /dev/null @@ -1,836 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. 
- */ - -#include -#include -#include - -#include -#include - -#include - -namespace vkcompute { - -void resize_quantize_node( - ComputeGraph* graph, - const std::vector& args, - const std::vector& extra_args) { - (void)extra_args; - - const ValueRef out = args.at(0).refs.at(0); - const ValueRef in = args.at(1).refs.at(0); - - const std::vector in_sizes = graph->sizes_of(in); - graph->virtual_resize(out, in_sizes); -} - -utils::uvec3 quantize_per_channel_local_wg_size( - ComputeGraph* graph, - const vkapi::ShaderInfo& shader, - const utils::uvec3& global_workgroup_size, - const std::vector& args, - const std::vector& resize_args) { - (void)shader; - (void)args; - (void)resize_args; - - const ValueRef input = args.at(1).refs.at(0); - - utils::uvec3 local_wg_size = - graph->create_local_wg_size(global_workgroup_size); - - // WORKAROUND: The CommandBuffer::dispatch function divides - // global_workgroup_size by local_workgroup_size to get the number of - // workgroups to dispatch. For per-channel quantization along the batch axis, - // we need to ensure that we dispatch the correct number of workgroups in the - // Z dimension to cover all batch-channel combinations. - // - // If local_wg_size[2] > 1, then div_up(global_workgroup_size[2], - // local_wg_size[2]) might reduce the number of workgroups dispatched. To - // ensure we dispatch global_workgroup_size[2] workgroups in the Z dimension, - // we set local_wg_size[2] = 1. - const auto input_sizes = graph->sizes_of(input); - if (input_sizes.size() == 4 && !graph->is_buffer_storage(input) && - global_workgroup_size[2] > 1) { - local_wg_size[2] = 1; - } - - return local_wg_size; -} - -utils::uvec3 quantize_block_wise_local_wg_size( - ComputeGraph* graph, - const vkapi::ShaderInfo& shader, - const utils::uvec3& global_workgroup_size, - const std::vector& args, - const std::vector& resize_args) { - (void)shader; - (void)resize_args; - const ValueRef input = args.at(1).refs.at(0); - - utils::uvec3 local_wg_size = - graph->create_local_wg_size(global_workgroup_size); - - // WORKAROUND: The CommandBuffer::dispatch function divides - // global_workgroup_size by local_workgroup_size to get the number of - // workgroups to dispatch. For per-channel quantization along the batch axis, - // we need to ensure that we dispatch the correct number of workgroups in the - // Z dimension to cover all batch-channel combinations. - // - // If local_wg_size[2] > 1, then div_up(global_workgroup_size[2], - // local_wg_size[2]) might reduce the number of workgroups dispatched. To - // ensure we dispatch global_workgroup_size[2] workgroups in the Z dimension, - // we set local_wg_size[2] = 1. 
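  // Worked example (hypothetical sizes, for illustration only): if
  // global_workgroup_size = {16, 16, 8} and create_local_wg_size picked
  // {4, 4, 4}, CommandBuffer::dispatch would launch only div_up(8, 4) = 2
  // workgroups along Z instead of the 8 required to cover every Z-slice;
  // forcing local_wg_size[2] = 1 gives div_up(8, 1) = 8 workgroups along Z,
  // one per slice.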
- const auto input_sizes = graph->sizes_of(input); - if (input_sizes.size() == 4 && !graph->is_buffer_storage(input) && - global_workgroup_size[2] > 1) { - local_wg_size[2] = 1; - } - - return local_wg_size; -} - -void add_quantize_per_tensor_node( - ComputeGraph& graph, - const ValueRef& input, - const ValueRef& scale, - const ValueRef& zero_point, - const ValueRef& quant_min, - const ValueRef& quant_max, - const ValueRef& output) { - std::string kernel_name("quantize_per_tensor"); - add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); - add_dtype_suffix(kernel_name, graph.dtype_of(input)); - add_dtype_suffix(kernel_name, graph.dtype_of(output)); - add_dtype_suffix(kernel_name, graph.dtype_of(scale)); - add_dtype_suffix(kernel_name, graph.dtype_of(zero_point)); - - // Handle optional quant_min and quant_max parameters independently - auto bounds = get_dtype_bounds(graph.dtype_of(output)); - - int quant_min_val, quant_max_val; - - // Handle quant_min - if (graph.val_is_none(quant_min)) { - quant_min_val = bounds.first; - } else { - VK_CHECK_COND( - graph.val_is_int(quant_min), - "quant_min must be an integer, got type: ", - graph.get_val_type(quant_min)); - quant_min_val = static_cast(graph.get_int(quant_min)); - } - - // Handle quant_max - if (graph.val_is_none(quant_max)) { - quant_max_val = bounds.second; - } else { - VK_CHECK_COND( - graph.val_is_int(quant_max), - "quant_max must be an integer, got type: ", - graph.get_val_type(quant_max)); - quant_max_val = static_cast(graph.get_int(quant_max)); - } - - vkapi::ParamsBindList param_ubos; - std::vector push_constants; - - if (graph.is_buffer_storage(input)) { - param_ubos = { - graph.numel_ubo(input), - graph.sizes_ubo(input), - graph.strides_ubo(input), - graph.sizes_ubo(output), - graph.strides_ubo(output)}; - } else { - param_ubos = { - graph.logical_limits_ubo(input), graph.logical_limits_ubo(output)}; - } - - push_constants = { - PushConstantDataInfo(&quant_min_val, sizeof(int)), - PushConstantDataInfo(&quant_max_val, sizeof(int)), - }; - - vkapi::SpecVarList spec_vars = { - graph.hashed_layout_of(output), - graph.hashed_layout_of(input), - }; - - graph.execute_nodes().emplace_back(new DynamicDispatchNode( - graph, - VK_KERNEL_FROM_STR(kernel_name), - default_pick_global_wg_size, - default_pick_local_wg_size, - // Inputs and Outputs - {{output, vkapi::kWrite}, - {input, vkapi::kRead}, - {{scale, zero_point}, vkapi::kRead}}, - // Shader param buffers - param_ubos, - // Push Constants - push_constants, - // Specialization Constants - spec_vars, - // Resize Args - {}, - // Resizing Logic - resize_quantize_node)); -} - -void add_quantize_per_token_node( - ComputeGraph& graph, - const ValueRef& input, - const ValueRef& scale, - const ValueRef& zero_point, - const ValueRef& quant_min, - const ValueRef& quant_max, - const ValueRef& output) { - std::string kernel_name("quantize_per_token"); - add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); - add_dtype_suffix(kernel_name, graph.dtype_of(input)); - add_dtype_suffix(kernel_name, graph.dtype_of(output)); - add_dtype_suffix(kernel_name, graph.dtype_of(scale)); - add_dtype_suffix(kernel_name, graph.dtype_of(zero_point)); - - // Handle optional quant_min and quant_max parameters independently - auto bounds = get_dtype_bounds(graph.dtype_of(output)); - - int quant_min_val, quant_max_val; - - // Handle quant_min - if (graph.val_is_none(quant_min)) { - quant_min_val = bounds.first; - } else { - VK_CHECK_COND( - graph.val_is_int(quant_min), - "quant_min must be an 
integer, got type: ", - graph.get_val_type(quant_min)); - quant_min_val = static_cast(graph.get_int(quant_min)); - } - - // Handle quant_max - if (graph.val_is_none(quant_max)) { - quant_max_val = bounds.second; - } else { - VK_CHECK_COND( - graph.val_is_int(quant_max), - "quant_max must be an integer, got type: ", - graph.get_val_type(quant_max)); - quant_max_val = static_cast(graph.get_int(quant_max)); - } - - int num_tokens = static_cast(graph.sizes_of(scale)[0]); - - vkapi::ParamsBindList param_ubos; - std::vector push_constants; - - if (graph.is_buffer_storage(input)) { - param_ubos = { - graph.numel_ubo(input), - graph.sizes_ubo(input), - graph.strides_ubo(input), - graph.sizes_ubo(output), - graph.strides_ubo(output), - }; - push_constants = { - PushConstantDataInfo(&num_tokens, sizeof(int)), - PushConstantDataInfo(&quant_min_val, sizeof(int)), - PushConstantDataInfo(&quant_max_val, sizeof(int)), - }; - } else { - param_ubos = { - graph.logical_limits_ubo(input), - graph.logical_limits_ubo(output), - }; - push_constants = { - PushConstantDataInfo(&num_tokens, sizeof(int)), - PushConstantDataInfo(&quant_min_val, sizeof(int)), - PushConstantDataInfo(&quant_max_val, sizeof(int)), - }; - } - - vkapi::SpecVarList spec_vars = { - graph.hashed_layout_of(output), - graph.hashed_layout_of(input), - }; - - graph.execute_nodes().emplace_back(new DynamicDispatchNode( - graph, - VK_KERNEL_FROM_STR(kernel_name), - default_pick_global_wg_size, - default_pick_local_wg_size, - // Inputs and Outputs - {{output, vkapi::kWrite}, - {input, vkapi::kRead}, - {{scale, zero_point}, vkapi::kRead}}, - // Shader param buffers - param_ubos, - // Push Constants - push_constants, - // Specialization Constants - spec_vars, - // Resize Args - {}, - // Resizing Logic - resize_quantize_node)); -} - -void add_quantize_per_channel_node( - ComputeGraph& graph, - const ValueRef& input, - const ValueRef& scale, - const ValueRef& zero_point, - const ValueRef& axis, - const ValueRef& quant_min, - const ValueRef& quant_max, - const ValueRef& output) { - std::string kernel_name("quantize_per_channel"); - add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); - add_dtype_suffix(kernel_name, graph.dtype_of(input)); - add_dtype_suffix(kernel_name, graph.dtype_of(output)); - add_dtype_suffix(kernel_name, graph.dtype_of(scale)); - add_dtype_suffix(kernel_name, graph.dtype_of(zero_point)); - - int axis_val = static_cast(graph.get_int(axis)); - - // Handle optional quant_min and quant_max parameters independently - auto bounds = get_dtype_bounds(graph.dtype_of(output)); - - int quant_min_val, quant_max_val; - - // Handle quant_min - if (graph.val_is_none(quant_min)) { - quant_min_val = bounds.first; - } else { - VK_CHECK_COND( - graph.val_is_int(quant_min), - "quant_min must be an integer, got type: ", - graph.get_val_type(quant_min)); - quant_min_val = static_cast(graph.get_int(quant_min)); - } - - // Handle quant_max - if (graph.val_is_none(quant_max)) { - quant_max_val = bounds.second; - } else { - VK_CHECK_COND( - graph.val_is_int(quant_max), - "quant_max must be an integer, got type: ", - graph.get_val_type(quant_max)); - quant_max_val = static_cast(graph.get_int(quant_max)); - } - - // Normalize axis and convert from NCHW to WHCN using utility functions - const auto input_sizes = graph.sizes_of(input); - const int64_t ndim = graph.dim_of(input); - - // Normalize axis to handle negative indices - axis_val = normalize(axis_val, ndim); - - // Convert from NCHW axis to WHCN axis for shader (vulkan representation) - 
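  // Worked example: nchw_dim_to_whcn_dim reverses the dimension order (WHCN
  // is NCHW read back to front), i.e. axis_whcn = ndim - 1 - axis_val. For a
  // 4D NCHW tensor: axis 0 (N) -> 3, axis 1 (C) -> 2, axis 2 (H) -> 1,
  // axis 3 (W) -> 0.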
int axis_whcn = nchw_dim_to_whcn_dim(axis_val, ndim); - - int num_channels; - if (axis_val == 0 && ndim == 4 && !graph.is_buffer_storage(input)) { - // For batch dimension quantization in 4D tensors, pass the actual number of - // channels so the shader can correctly unfold the batch-channel folding - num_channels = static_cast(input_sizes[1]); // Channel dimension - } else { - num_channels = static_cast(input_sizes[axis_val]); - } - - vkapi::ParamsBindList param_ubos; - std::vector push_constants; - - if (graph.is_buffer_storage(input)) { - param_ubos = { - graph.numel_ubo(input), - graph.sizes_ubo(input), - graph.strides_ubo(input), - graph.sizes_ubo(output), - graph.strides_ubo(output), - }; - push_constants = { - PushConstantDataInfo(&axis_whcn, sizeof(int)), - PushConstantDataInfo(&num_channels, sizeof(int)), - PushConstantDataInfo(&quant_min_val, sizeof(int)), - PushConstantDataInfo(&quant_max_val, sizeof(int)), - }; - } else { - param_ubos = { - graph.logical_limits_ubo(input), - graph.logical_limits_ubo(output), - }; - push_constants = { - PushConstantDataInfo(&axis_whcn, sizeof(int)), - PushConstantDataInfo(&num_channels, sizeof(int)), - PushConstantDataInfo(&quant_min_val, sizeof(int)), - PushConstantDataInfo(&quant_max_val, sizeof(int)), - }; - } - - vkapi::SpecVarList spec_vars = { - graph.hashed_layout_of(output), - graph.hashed_layout_of(input), - }; - - graph.execute_nodes().emplace_back(new DynamicDispatchNode( - graph, - VK_KERNEL_FROM_STR(kernel_name), - default_pick_global_wg_size, - quantize_per_channel_local_wg_size, - // Inputs and Outputs - {{output, vkapi::kWrite}, - {input, vkapi::kRead}, - {{scale, zero_point}, vkapi::kRead}}, - // Shader param buffers - param_ubos, - // Push Constants - push_constants, - // Specialization Constants - spec_vars, - // Resize Args - {}, - // Resizing Logic - resize_quantize_node)); -} - -void add_quantize_block_wise_node( - ComputeGraph& graph, - const ValueRef& input, - const ValueRef& block_size, - const ValueRef& scale, - const ValueRef& zero_point, - const ValueRef& quant_min, - const ValueRef& quant_max, - const ValueRef& output) { - std::string kernel_name("quantize_block_wise"); - add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); - add_dtype_suffix(kernel_name, graph.dtype_of(input)); - add_dtype_suffix(kernel_name, graph.dtype_of(output)); - add_dtype_suffix(kernel_name, graph.dtype_of(scale)); - add_dtype_suffix(kernel_name, graph.dtype_of(zero_point)); - - // Handle optional quant_min and quant_max parameters independently - auto bounds = get_dtype_bounds(graph.dtype_of(output)); - - int quant_min_val, quant_max_val; - - // Handle quant_min - if (graph.val_is_none(quant_min)) { - quant_min_val = bounds.first; - } else { - VK_CHECK_COND( - graph.val_is_int(quant_min), - "quant_min must be an integer, got type: ", - graph.get_val_type(quant_min)); - quant_min_val = static_cast(graph.get_int(quant_min)); - } - - // Handle quant_max - if (graph.val_is_none(quant_max)) { - quant_max_val = bounds.second; - } else { - VK_CHECK_COND( - graph.val_is_int(quant_max), - "quant_max must be an integer, got type: ", - graph.get_val_type(quant_max)); - quant_max_val = static_cast(graph.get_int(quant_max)); - } - - const auto input_sizes = graph.sizes_of(input); - const auto block_size_list = graph.get_int_list(block_size); - - // Convert PyTorch dimensions to WHCN order for shader - utils::ivec4 block_size_vec = utils::make_whcn_ivec4(*block_size_list); - utils::ivec4 tensor_size_whcn = 
utils::make_whcn_ivec4(input_sizes); - - // Calculate numBlocks: tensorSize / blockSize (both in WHCN order) - utils::ivec4 num_blocks_vec = { - tensor_size_whcn[0] / block_size_vec[0], - tensor_size_whcn[1] / block_size_vec[1], - tensor_size_whcn[2] / block_size_vec[2], - tensor_size_whcn[3] / block_size_vec[3]}; - - // Calculate blockStride: pre-computed linear strides for the block grid - utils::ivec4 block_stride_vec = { - 1, - num_blocks_vec[0], - num_blocks_vec[0] * num_blocks_vec[1], - num_blocks_vec[0] * num_blocks_vec[1] * num_blocks_vec[2]}; - - vkapi::ParamsBindList param_ubos; - std::vector push_constants; - - if (graph.is_buffer_storage(input)) { - param_ubos = { - graph.numel_ubo(input), - graph.sizes_ubo(input), - graph.strides_ubo(input), - graph.sizes_ubo(output), - graph.strides_ubo(output)}; - } else { - param_ubos = { - graph.logical_limits_ubo(input), graph.logical_limits_ubo(output)}; - } - - push_constants = { - PushConstantDataInfo(&block_size_vec, sizeof(block_size_vec)), - PushConstantDataInfo(&num_blocks_vec, sizeof(num_blocks_vec)), - PushConstantDataInfo(&block_stride_vec, sizeof(block_stride_vec)), - PushConstantDataInfo(&quant_min_val, sizeof(int)), - PushConstantDataInfo(&quant_max_val, sizeof(int)), - }; - - vkapi::SpecVarList spec_vars = { - graph.hashed_layout_of(output), - graph.hashed_layout_of(input), - }; - - graph.execute_nodes().emplace_back(new DynamicDispatchNode( - graph, - VK_KERNEL_FROM_STR(kernel_name), - default_pick_global_wg_size, - quantize_block_wise_local_wg_size, - // Inputs and Outputs - {{output, vkapi::kWrite}, - {input, vkapi::kRead}, - {{scale, zero_point}, vkapi::kRead}}, - // Shader param buffers - param_ubos, - // Push Constants - push_constants, - // Specialization Constants - spec_vars, - // Resize Args - {}, - // Resizing Logic - resize_quantize_node)); -} - -void quantize_per_tensor_impl( - ComputeGraph& graph, - const std::vector& args) { - int arg_idx = 0; - const ValueRef input = args[arg_idx++]; - const ValueRef scale = args[arg_idx++]; - const ValueRef zero_point = args[arg_idx++]; - const ValueRef quant_min = args[arg_idx++]; - const ValueRef quant_max = args[arg_idx++]; - const ValueRef dtype = args[arg_idx++]; - const ValueRef output = args[arg_idx++]; - - // Suppress unused variable warning - dtype is inferred from output - (void)dtype; - - // Check tensor types - VK_CHECK_COND(graph.val_is_tensor(input)); - VK_CHECK_COND(graph.val_is_tensor(scale)); - VK_CHECK_COND(graph.val_is_tensor(zero_point)); - VK_CHECK_COND(graph.val_is_tensor(output)); - - // Verify input is a floating point type - VK_CHECK_COND( - graph.dtype_of(input) == vkapi::kDouble || - graph.dtype_of(input) == vkapi::kFloat || - graph.dtype_of(input) == vkapi::kHalf); - - // Get scale and zero point dtypes - vkapi::ScalarType scale_dtype = graph.dtype_of(scale); - vkapi::ScalarType zero_point_dtype = graph.dtype_of(zero_point); - - // Verify supported types for scale (fp32 only for now) - VK_CHECK_COND(scale_dtype == vkapi::kFloat); - - // Verify supported types for zero point (int32, int8, fp32) - VK_CHECK_COND( - zero_point_dtype == vkapi::kInt || zero_point_dtype == vkapi::kChar || - zero_point_dtype == vkapi::kFloat); - - add_quantize_per_tensor_node( - graph, input, scale, zero_point, quant_min, quant_max, output); -} - -void quantize_per_token_impl( - ComputeGraph& graph, - const std::vector& args) { - int arg_idx = 0; - const ValueRef input = args[arg_idx++]; - const ValueRef scale = args[arg_idx++]; - const ValueRef zero_point = 
args[arg_idx++]; - const ValueRef quant_min = args[arg_idx++]; - const ValueRef quant_max = args[arg_idx++]; - const ValueRef dtype = args[arg_idx++]; - const ValueRef output = args[arg_idx++]; - - // Suppress unused variable warning - dtype is inferred from output - (void)dtype; - - // Check tensor types - VK_CHECK_COND(graph.val_is_tensor(input)); - VK_CHECK_COND(graph.val_is_tensor(scale)); - VK_CHECK_COND(graph.val_is_tensor(zero_point)); - VK_CHECK_COND(graph.val_is_tensor(output)); - - // Verify input is a floating point type - VK_CHECK_COND( - graph.dtype_of(input) == vkapi::kDouble || - graph.dtype_of(input) == vkapi::kFloat || - graph.dtype_of(input) == vkapi::kHalf); - - // Get scale and zero point dtypes - vkapi::ScalarType scale_dtype = graph.dtype_of(scale); - vkapi::ScalarType zero_point_dtype = graph.dtype_of(zero_point); - - // Verify supported types for scale (fp32 only for now) - VK_CHECK_COND(scale_dtype == vkapi::kFloat); - - // Verify supported types for zero point (int32, int8, fp32) - VK_CHECK_COND( - zero_point_dtype == vkapi::kInt || zero_point_dtype == vkapi::kChar || - zero_point_dtype == vkapi::kFloat); - - // Check that scale and zero_point have buffer storage and width packing - VK_CHECK_COND(graph.is_buffer_storage(scale)); - VK_CHECK_COND(graph.packed_dim_of(scale) == WHCN::kWidthDim); - VK_CHECK_COND(graph.is_buffer_storage(zero_point)); - VK_CHECK_COND(graph.packed_dim_of(zero_point) == WHCN::kWidthDim); - - // Check that tensors with texture storage have standard axis map - if (!graph.is_buffer_storage(input)) { - VK_CHECK_COND(graph.has_standard_axis_map(input)); - } - if (!graph.is_buffer_storage(output)) { - VK_CHECK_COND(graph.has_standard_axis_map(output)); - } - - // Calculate number of tokens (product of all dimensions except the last one) - int64_t num_tokens = 1; - const auto input_sizes = graph.sizes_of(input); - for (size_t i = 0; i < input_sizes.size() - 1; i++) { - num_tokens *= input_sizes[i]; - } - - const auto scale_sizes = graph.sizes_of(scale); - const auto zero_point_sizes = graph.sizes_of(zero_point); - - // Calculate total number of elements in scale and zero_point tensors - int64_t scale_numel = 1; - for (size_t i = 0; i < scale_sizes.size(); i++) { - scale_numel *= scale_sizes[i]; - } - - int64_t zero_point_numel = 1; - for (size_t i = 0; i < zero_point_sizes.size(); i++) { - zero_point_numel *= zero_point_sizes[i]; - } - - // Check that the total number of elements matches num_tokens - // This allows for both 1D tensors (size [num_tokens]) and reshaped tensors - // (size [num_tokens, 1]) - VK_CHECK_COND(scale_numel == num_tokens); - VK_CHECK_COND(zero_point_numel == num_tokens); - - add_quantize_per_token_node( - graph, input, scale, zero_point, quant_min, quant_max, output); -} - -void quantize_per_channel_impl( - ComputeGraph& graph, - const std::vector& args) { - int arg_idx = 0; - const ValueRef input = args[arg_idx++]; - const ValueRef scale = args[arg_idx++]; - const ValueRef zero_point = args[arg_idx++]; - const ValueRef axis = args[arg_idx++]; - const ValueRef quant_min = args[arg_idx++]; - const ValueRef quant_max = args[arg_idx++]; - const ValueRef dtype = args[arg_idx++]; - const ValueRef output = args[arg_idx++]; - - // Suppress unused variable warning - dtype is inferred from output - (void)dtype; - - // Check tensor types - VK_CHECK_COND(graph.val_is_tensor(input)); - VK_CHECK_COND(graph.val_is_tensor(scale)); - VK_CHECK_COND(graph.val_is_tensor(zero_point)); - VK_CHECK_COND(graph.val_is_tensor(output)); - - // Verify 
input is a floating point type - VK_CHECK_COND( - graph.dtype_of(input) == vkapi::kDouble || - graph.dtype_of(input) == vkapi::kFloat || - graph.dtype_of(input) == vkapi::kHalf); - - // Get scale and zero point dtypes - vkapi::ScalarType scale_dtype = graph.dtype_of(scale); - vkapi::ScalarType zero_point_dtype = graph.dtype_of(zero_point); - - // Verify supported types for scale (fp32 only for now) - VK_CHECK_COND(scale_dtype == vkapi::kFloat); - - // Verify supported types for zero point (int32, int8, fp32) - VK_CHECK_COND( - zero_point_dtype == vkapi::kInt || zero_point_dtype == vkapi::kChar || - zero_point_dtype == vkapi::kFloat); - - // Check that scale and zero_point have buffer storage and width packing - VK_CHECK_COND(graph.is_buffer_storage(scale)); - VK_CHECK_COND(graph.packed_dim_of(scale) == WHCN::kWidthDim); - VK_CHECK_COND(graph.is_buffer_storage(zero_point)); - VK_CHECK_COND(graph.packed_dim_of(zero_point) == WHCN::kWidthDim); - - // Check that tensors with texture storage have standard axis map - if (!graph.is_buffer_storage(input)) { - VK_CHECK_COND(graph.has_standard_axis_map(input)); - } - if (!graph.is_buffer_storage(output)) { - VK_CHECK_COND(graph.has_standard_axis_map(output)); - } - - // Normalize axis - int axis_val = static_cast(graph.get_int(axis)); - const auto input_sizes = graph.sizes_of(input); - int64_t ndim = graph.dim_of(input); - if (axis_val < 0) { - axis_val += ndim; - } - - // Verify axis is valid - VK_CHECK_COND(axis_val >= 0 && axis_val < ndim); - - // Get number of channels along the specified axis - int64_t num_channels = input_sizes[axis_val]; - - const auto scale_sizes = graph.sizes_of(scale); - const auto zero_point_sizes = graph.sizes_of(zero_point); - - // Calculate total number of elements in scale and zero_point tensors - int64_t scale_numel = 1; - for (size_t i = 0; i < scale_sizes.size(); i++) { - scale_numel *= scale_sizes[i]; - } - - int64_t zero_point_numel = 1; - for (size_t i = 0; i < zero_point_sizes.size(); i++) { - zero_point_numel *= zero_point_sizes[i]; - } - - // Check that the total number of elements matches num_channels - VK_CHECK_COND(scale_numel == num_channels); - VK_CHECK_COND(zero_point_numel == num_channels); - - add_quantize_per_channel_node( - graph, input, scale, zero_point, axis, quant_min, quant_max, output); -} - -void quantize_affine_impl( - ComputeGraph& graph, - const std::vector& args) { - int arg_idx = 0; - const ValueRef input = args[arg_idx++]; - const ValueRef block_size = args[arg_idx++]; - const ValueRef scale = args[arg_idx++]; - const ValueRef zero_point = args[arg_idx++]; - const ValueRef output_dtype = args[arg_idx++]; - const ValueRef quant_min = args[arg_idx++]; - const ValueRef quant_max = args[arg_idx++]; - const ValueRef output = args[arg_idx++]; - - // Suppress unused variable warnings - (void)output_dtype; - - // Check tensor types - VK_CHECK_COND(graph.val_is_tensor(input)); - VK_CHECK_COND(graph.val_is_tensor(scale)); - VK_CHECK_COND(graph.val_is_tensor(zero_point)); - VK_CHECK_COND(graph.val_is_tensor(output)); - - // Verify input is a floating point type - VK_CHECK_COND( - graph.dtype_of(input) == vkapi::kDouble || - graph.dtype_of(input) == vkapi::kFloat || - graph.dtype_of(input) == vkapi::kHalf); - - // Get scale and zero point dtypes - vkapi::ScalarType scale_dtype = graph.dtype_of(scale); - vkapi::ScalarType zero_point_dtype = graph.dtype_of(zero_point); - - // Verify supported types for scale (fp32 only for now) - VK_CHECK_COND(scale_dtype == vkapi::kFloat); - - // Verify supported 
types for zero point (int32, int8, fp32) - VK_CHECK_COND( - zero_point_dtype == vkapi::kInt || zero_point_dtype == vkapi::kChar || - zero_point_dtype == vkapi::kFloat); - - // Check that scale and zero_point have buffer storage and width packing - VK_CHECK_COND(graph.is_buffer_storage(scale)); - VK_CHECK_COND(graph.packed_dim_of(scale) == WHCN::kWidthDim); - VK_CHECK_COND(graph.is_buffer_storage(zero_point)); - VK_CHECK_COND(graph.packed_dim_of(zero_point) == WHCN::kWidthDim); - - // Check that tensors with texture storage have standard axis map - if (!graph.is_buffer_storage(input)) { - VK_CHECK_COND(graph.has_standard_axis_map(input)); - } - if (!graph.is_buffer_storage(output)) { - VK_CHECK_COND(graph.has_standard_axis_map(output)); - } - - // Verify block_size is valid (each dimension must divide evenly into input - // size) - const auto input_sizes = graph.sizes_of(input); - const auto block_size_list = graph.get_int_list(block_size); - VK_CHECK_COND(block_size_list->size() == input_sizes.size()); - - for (size_t i = 0; i < input_sizes.size(); i++) { - if ((*block_size_list)[i] > 1) { - VK_CHECK_COND( - input_sizes[i] % (*block_size_list)[i] == 0, - "Input size at dimension ", - i, - " (", - input_sizes[i], - ") must be divisible by block_size at dimension ", - i, - " (", - (*block_size_list)[i], - ")"); - } - } - - add_quantize_block_wise_node( - graph, - input, - block_size, - scale, - zero_point, - quant_min, - quant_max, - output); -} - -REGISTER_OPERATORS { - VK_REGISTER_OP( - quantized_decomposed.quantize_per_tensor.tensor, - quantize_per_tensor_impl); - VK_REGISTER_OP( - quantized_decomposed.quantize_per_token.default, quantize_per_token_impl); - VK_REGISTER_OP( - quantized_decomposed.quantize_per_channel.default, - quantize_per_channel_impl); - - // TorchAO affine quantization operators - VK_REGISTER_OP(torchao.quantize_affine.default, quantize_affine_impl); -} - -} // namespace vkcompute diff --git a/backends/vulkan/test/op_tests/choose_qparams_test.cpp b/backends/vulkan/test/op_tests/choose_qparams_test.cpp deleted file mode 100644 index 3b1094a1e84..00000000000 --- a/backends/vulkan/test/op_tests/choose_qparams_test.cpp +++ /dev/null @@ -1,786 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. 
- */ - -#include - -#include - -#include -#include -#include - -#include -#include - -#include "test_utils.h" - -#include -#include - -namespace torch { -namespace executor { -namespace native { - -// Forward declarations of the functions we're testing -std::tuple<Tensor&, Tensor&> choose_qparams_tensor_out( - const Tensor& input, - int64_t quant_min, - int64_t quant_max, - ET_UNUSED double eps, - ScalarType dtype, - Tensor& scale_out, - Tensor& zero_point_out); - -std::tuple<Tensor&, Tensor&> choose_qparams_per_token_asymmetric_out( - const Tensor& input, - ScalarType dtype, - Tensor& scale_out, - Tensor& zero_point_out); - -// Wrapper function for choose_qparams_tensor_out without context -Tensor& choose_qparams_tensor_out_no_context( - const Tensor& input, - int64_t quant_min, - int64_t quant_max, - ET_UNUSED double eps, - ScalarType dtype, - Tensor& scale_out, - Tensor& zero_point_out) { - torch::executor::native::choose_qparams_tensor_out( - input, quant_min, quant_max, eps, dtype, scale_out, zero_point_out); - return scale_out; -} - -// Wrapper function for choose_qparams_per_token_asymmetric_out without context -Tensor& choose_qparams_per_token_asymmetric_out_no_context( - const Tensor& input, - ScalarType dtype, - Tensor& scale_out, - Tensor& zero_point_out) { - torch::executor::native::choose_qparams_per_token_asymmetric_out( - input, dtype, scale_out, zero_point_out); - return scale_out; -} - -// ATen wrapper for choose_qparams_tensor -std::tuple<at::Tensor, at::Tensor> choose_qparams_tensor_aten( - const at::Tensor& input, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype) { - auto scale_out = at::empty({}, at::device(at::kCPU).dtype(at::kDouble)); - auto zero_point_out = at::empty({}, at::device(at::kCPU).dtype(at::kLong)); - double eps = 1e-7; - - ScalarType et_dtype = at_scalartype_to_et_scalartype(dtype); - - // Use WRAP_TO_ATEN with the wrapper function - WRAP_TO_ATEN(choose_qparams_tensor_out_no_context, 5) - (input, quant_min, quant_max, eps, et_dtype, scale_out, zero_point_out); - - return {scale_out, zero_point_out}; -} - -// ATen wrapper for choose_qparams_per_token_asymmetric -std::tuple<at::Tensor, at::Tensor> choose_qparams_per_token_asymmetric_aten( - const at::Tensor& input, - at::ScalarType dtype) { - // Calculate output sizes for scale and zero_point tensors - std::vector<int64_t> output_sizes; - for (int64_t i = 0; i < input.dim() - 1; i++) { - output_sizes.push_back(input.size(i)); - } - output_sizes.push_back(1); - - auto scale_out = - at::empty(output_sizes, at::device(at::kCPU).dtype(at::kDouble)); - auto zero_point_out = - at::empty(output_sizes, at::device(at::kCPU).dtype(at::kLong)); - - ScalarType et_dtype = at_scalartype_to_et_scalartype(dtype); - - // Use WRAP_TO_ATEN with the wrapper function - WRAP_TO_ATEN(choose_qparams_per_token_asymmetric_out_no_context, 2) - (input, et_dtype, scale_out, zero_point_out); - - return {scale_out, zero_point_out}; -} - -} // namespace native -} // namespace executor -} // namespace torch - -// -// Reference Implementation -// - -/* - * Reference implementation of choose_qparams_tensor - */ -std::tuple<at::Tensor, at::Tensor> choose_qparams_tensor_reference_impl( - const at::Tensor& input, - int64_t quant_min, - int64_t quant_max) { - // Create output tensors - at::Tensor scale_out = at::empty({}, at::device(at::kCPU).dtype(at::kDouble)); - at::Tensor zero_point_out = - at::empty({}, at::device(at::kCPU).dtype(at::kLong)); - - // Find min and max values in the input tensor - float min_val = input.min().item<float>(); - float max_val = input.max().item<float>(); - - // Extend the [min, max] interval to ensure it contains 0 - min_val 
= std::min(min_val, 0.f); - max_val = std::max(max_val, 0.f); - - // Calculate scale - double scale = - (static_cast<double>(max_val) - min_val) / (quant_max - quant_min); - - // Handle small scale - constexpr float SMALL_SCALE_THRESHOLD = 6.1e-5f; - if (float(scale) == 0.0f || std::isinf(1.0f / float(scale))) { - scale = 0.1; - } - - if (scale < SMALL_SCALE_THRESHOLD) { - float org_scale = scale; - scale = SMALL_SCALE_THRESHOLD; - // Adjust min and max based on new scale - if (min_val == 0.0f) { - max_val = SMALL_SCALE_THRESHOLD * (quant_max - quant_min); - } else if (max_val == 0.0f) { - min_val = -SMALL_SCALE_THRESHOLD * (quant_max - quant_min); - } else { - float amplifier = SMALL_SCALE_THRESHOLD / org_scale; - min_val *= amplifier; - max_val *= amplifier; - } - } - - // Calculate zero point - double zero_point_from_min = quant_min - min_val / static_cast<double>(scale); - double zero_point_from_max = quant_max - max_val / static_cast<double>(scale); - double zero_point_from_min_error = - std::abs(quant_min) - std::abs(min_val / static_cast<double>(scale)); - double zero_point_from_max_error = - std::abs(quant_max) - std::abs(max_val / static_cast<double>(scale)); - double initial_zero_point = - zero_point_from_min_error < zero_point_from_max_error - ? zero_point_from_min - : zero_point_from_max; - - // Nudge zero point to be an integer - int64_t nudged_zero_point = 0; - if (initial_zero_point < quant_min) { - nudged_zero_point = quant_min; - } else if (initial_zero_point > quant_max) { - nudged_zero_point = quant_max; - } else { - nudged_zero_point = std::nearbyint(static_cast<float>(initial_zero_point)); - } - - // Set output values - use fill_ for scalar tensors - scale_out.fill_(scale); - zero_point_out.fill_(nudged_zero_point); - - return std::make_tuple(scale_out, zero_point_out); -} - -/* - * Reference implementation of choose_qparams_per_token_asymmetric - */ -std::tuple<at::Tensor, at::Tensor> -choose_qparams_per_token_asymmetric_reference_impl( - const at::Tensor& input, - at::ScalarType dtype) { - // For per-token quantization, we need to compute scale and zero_point for - // each token - int64_t quant_min = -128; - int64_t quant_max = 127; - - // Calculate output sizes - std::vector<int64_t> output_sizes; - for (int64_t i = 0; i < input.dim() - 1; i++) { - output_sizes.push_back(input.size(i)); - } - output_sizes.push_back(1); - - // Create output tensors - at::Tensor scale_out = - at::empty(output_sizes, at::device(at::kCPU).dtype(at::kDouble)); - at::Tensor zero_point_out = - at::empty(output_sizes, at::device(at::kCPU).dtype(at::kLong)); - - // Calculate number of tokens - int64_t num_tokens = 1; - for (int64_t i = 0; i < input.dim() - 1; i++) { - num_tokens *= input.size(i); - } - - // Reshape input to [num_tokens, last_dim] - at::Tensor reshaped_input = input.reshape({num_tokens, input.size(-1)}); - - // Process each token - for (int64_t token_idx = 0; token_idx < num_tokens; token_idx++) { - at::Tensor token = reshaped_input[token_idx]; - - // Find min and max values for this token - float min_val = token.min().item<float>(); - float max_val = token.max().item<float>(); - - // Extend the [min, max] interval to ensure it contains 0 - min_val = std::min(min_val, 0.f); - max_val = std::max(max_val, 0.f); - - // Calculate scale - double scale = - (static_cast<double>(max_val) - min_val) / (quant_max - quant_min); - - // Handle small scale - constexpr float SMALL_SCALE_THRESHOLD = 6.1e-5f; - if (float(scale) == 0.0f || std::isinf(1.0f / float(scale))) { - scale = 0.1; - } - - if (scale < SMALL_SCALE_THRESHOLD) { - float org_scale = scale; - scale = 
SMALL_SCALE_THRESHOLD; - // Adjust min and max based on new scale - if (min_val == 0.0f) { - max_val = SMALL_SCALE_THRESHOLD * (quant_max - quant_min); - } else if (max_val == 0.0f) { - min_val = -SMALL_SCALE_THRESHOLD * (quant_max - quant_min); - } else { - float amplifier = SMALL_SCALE_THRESHOLD / org_scale; - min_val *= amplifier; - max_val *= amplifier; - } - } - - // Calculate zero point - double zero_point_from_min = - quant_min - min_val / static_cast(scale); - double zero_point_from_max = - quant_max - max_val / static_cast(scale); - double zero_point_from_min_error = - std::abs(quant_min) - std::abs(min_val / static_cast(scale)); - double zero_point_from_max_error = - std::abs(quant_max) - std::abs(max_val / static_cast(scale)); - double initial_zero_point = - zero_point_from_min_error < zero_point_from_max_error - ? zero_point_from_min - : zero_point_from_max; - - // Nudge zero point to be an integer - int64_t nudged_zero_point = 0; - if (initial_zero_point < quant_min) { - nudged_zero_point = quant_min; - } else if (initial_zero_point > quant_max) { - nudged_zero_point = quant_max; - } else { - nudged_zero_point = - std::nearbyint(static_cast(initial_zero_point)); - } - - // Set output values for this token - use index_put_ for safety - scale_out.view({num_tokens, 1}).index_put_({token_idx, 0}, scale); - zero_point_out.view({num_tokens, 1}) - .index_put_({token_idx, 0}, nudged_zero_point); - } - - return std::make_tuple(scale_out, zero_point_out); -} - -// Forward declaration of implementation functions -void test_vulkan_choose_qparams_tensor_impl( - const std::vector& input_sizes, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype, - const vkcompute::utils::StorageType in_storage, - const vkcompute::utils::StorageType out_storage); - -void test_vulkan_choose_qparams_per_token_asymmetric_impl( - const std::vector& input_sizes, - at::ScalarType dtype, - const vkcompute::utils::StorageType in_storage, - const vkcompute::utils::StorageType out_storage); - -// Wrapper function to test both buffer and texture storage types -void test_vulkan_choose_qparams_tensor( - const std::vector& input_sizes, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype) { - // Test with buffer storage - test_vulkan_choose_qparams_tensor_impl( - input_sizes, - quant_min, - quant_max, - dtype, - vkcompute::utils::kBuffer, - vkcompute::utils::kBuffer); - - // Test with texture storage - test_vulkan_choose_qparams_tensor_impl( - input_sizes, - quant_min, - quant_max, - dtype, - vkcompute::utils::kTexture3D, - vkcompute::utils::kTexture3D); -} - -// Wrapper function to test both buffer and texture storage types -void test_vulkan_choose_qparams_per_token_asymmetric( - const std::vector& input_sizes, - at::ScalarType dtype) { - // Test with buffer storage - test_vulkan_choose_qparams_per_token_asymmetric_impl( - input_sizes, dtype, vkcompute::utils::kBuffer, vkcompute::utils::kBuffer); - - // Test with texture storage - test_vulkan_choose_qparams_per_token_asymmetric_impl( - input_sizes, - dtype, - vkcompute::utils::kTexture3D, - vkcompute::utils::kTexture3D); -} - -void test_reference_choose_qparams_tensor( - const std::vector& input_sizes, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype) { - std::vector input_sizes_int64( - input_sizes.begin(), input_sizes.end()); - at::Tensor input = - at::rand(input_sizes_int64, at::device(at::kCPU).dtype(at::kFloat)); - - // Get reference output - auto [reference_scale, reference_zero_point] = - 
choose_qparams_tensor_reference_impl(input, quant_min, quant_max); - - // Get implementation output - auto [impl_scale, impl_zero_point] = - torch::executor::native::choose_qparams_tensor_aten( - input, quant_min, quant_max, dtype); - - // Compare outputs - const bool scale_correct = at::allclose(reference_scale, impl_scale); - const bool zero_point_correct = - at::equal(reference_zero_point, impl_zero_point); - - if (!scale_correct || !zero_point_correct) { - std::cout << "\n" - << "Failed with parameters: " << std::endl; - std::cout << " quant_min: " << quant_min << std::endl; - std::cout << " quant_max: " << quant_max << std::endl; - - std::cout << "input:" << std::endl; - std::cout << input << std::endl; - std::cout << "reference scale:" << std::endl; - std::cout << reference_scale << std::endl; - std::cout << "implementation scale:" << std::endl; - std::cout << impl_scale << std::endl; - std::cout << "reference zero_point:" << std::endl; - std::cout << reference_zero_point << std::endl; - std::cout << "implementation zero_point:" << std::endl; - std::cout << impl_zero_point << std::endl; - } - - ASSERT_TRUE(scale_correct && zero_point_correct); -} - -void test_vulkan_choose_qparams_tensor_impl( - const std::vector& input_sizes, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype, - const vkcompute::utils::StorageType in_storage, - const vkcompute::utils::StorageType out_storage) { - std::vector input_sizes_int64( - input_sizes.begin(), input_sizes.end()); - at::Tensor input = - at::rand(input_sizes_int64, at::device(at::kCPU).dtype(at::kFloat)); - - // Get reference output - auto [reference_scale, reference_zero_point] = - torch::executor::native::choose_qparams_tensor_aten( - input, quant_min, quant_max, dtype); - - // Build Vulkan choose_qparams_tensor graph - using namespace vkcompute; - - GraphConfig config; - config.set_storage_type_override(in_storage); - ComputeGraph graph(config); - - IOValueRef r_input = graph.add_input_tensor( - input.sizes().vec(), from_at_scalartype(input.scalar_type()), in_storage); - - const ValueRef r_quant_min = graph.add_scalar(quant_min); - const ValueRef r_quant_max = graph.add_scalar(quant_max); - - // Output tensors - const ValueRef r_scale = graph.add_tensor({}, vkapi::kFloat, out_storage); - const ValueRef r_zero_point = graph.add_tensor({}, vkapi::kInt, out_storage); - - // Create output tuple - const ValueRef r_out_tuple = graph.add_value_list({r_scale, r_zero_point}); - - // Add eps and dtype parameters to match ATen signature - const ValueRef r_eps = graph.add_scalar(6.1e-5); - const ValueRef r_dtype = - graph.add_scalar(static_cast(dtype)); - - VK_GET_OP_FN("quantized_decomposed.choose_qparams.tensor") - (graph, - { - r_input.value, - r_quant_min, - r_quant_max, - r_eps, - r_dtype, - r_out_tuple, - }); - - ValueRef staging_scale = graph.set_output_tensor(r_scale); - ValueRef staging_zero_point = graph.set_output_tensor(r_zero_point); - - graph.prepare(); - - graph.prepack(); - - // Run Vulkan choose_qparams_tensor - graph.copy_into_staging( - r_input.staging, input.const_data_ptr(), input.numel()); - - graph.execute(); - - // Create output tensors to hold the results - use types that match GPU output - at::Tensor vk_scale = - at::empty({}, at::device(at::kCPU).dtype(at::kFloat)).contiguous(); - at::Tensor vk_zero_point = - at::empty({}, at::device(at::kCPU).dtype(at::kInt)).contiguous(); - - // Copy results from GPU to CPU - graph.copy_from_staging( - staging_scale, vk_scale.mutable_data_ptr(), vk_scale.numel()); - 
graph.copy_from_staging( - staging_zero_point, - vk_zero_point.mutable_data_ptr(), - vk_zero_point.numel()); - - // Convert reference values to match Vulkan output types for comparison - at::Tensor reference_scale_float = reference_scale.to(at::kFloat); - at::Tensor reference_zero_point_int = reference_zero_point.to(at::kInt); - - // Compare outputs - const bool scale_correct = at::allclose(reference_scale_float, vk_scale); - const bool zero_point_correct = - at::equal(reference_zero_point_int, vk_zero_point); - - if (!scale_correct || !zero_point_correct) { - std::cout << "\n" - << "Failed with parameters: " << std::endl; - std::cout << " quant_min: " << quant_min << std::endl; - std::cout << " quant_max: " << quant_max << std::endl; - std::cout << " storage type: " - << (in_storage == vkcompute::utils::kBuffer ? "buffer" - : "texture") - << std::endl; - - // make sure that there arent a ton of elements in the input tensor - if (input.numel() < 100) { - std::cout << "input:" << std::endl; - std::cout << input << "\n" << std::endl; - std::cout << "reference scale:" << std::endl; - std::cout << reference_scale << std::endl; - std::cout << "vulkan scale:" << std::endl; - std::cout << vk_scale << "\n" << std::endl; - std::cout << "reference zero_point:" << std::endl; - std::cout << reference_zero_point << std::endl; - std::cout << "vulkan zero_point:" << std::endl; - std::cout << vk_zero_point << std::endl; - } - } - - ASSERT_TRUE(scale_correct && zero_point_correct); -} - -TEST(VulkanChooseQparamsTest, test_reference_choose_qparams_tensor_int8) { - test_reference_choose_qparams_tensor( - {2, 3, 4}, // input sizes - -128, // quant_min - 127, // quant_max - at::kChar); -} - -TEST(VulkanChooseQparamsTest, test_vulkan_choose_qparams_tensor_uint8_4D) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - test_vulkan_choose_qparams_tensor( - {5, 3, 2, 4}, // input sizes - 0, // quant_min - 255, // quant_max - at::kByte); -} - -TEST(VulkanChooseQparamsTest, test_vulkan_choose_qparams_tensor_int8_2D) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - test_vulkan_choose_qparams_tensor( - {5, 5}, // input sizes - -128, // quant_min - 127, // quant_max - at::kChar); -} - -TEST(VulkanChooseQparamsTest, test_vulkan_choose_qparams_tensor_int8_3D) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - test_vulkan_choose_qparams_tensor( - {12, 8, 2}, // input sizes - -128, // quant_min - 127, // quant_max - at::kChar); -} - -TEST(VulkanChooseQparamsTest, test_vulkan_choose_qparams_tensor_int8_4D) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - test_vulkan_choose_qparams_tensor( - {10, 10, 6, 4}, // input sizes - -128, // quant_min - 127, // quant_max - at::kChar); -} - -void test_reference_choose_qparams_per_token_asymmetric( - const std::vector& input_sizes, - at::ScalarType dtype) { - std::vector input_sizes_int64( - input_sizes.begin(), input_sizes.end()); - at::Tensor input = - at::rand(input_sizes_int64, at::device(at::kCPU).dtype(at::kFloat)); - - // Get reference output - auto [reference_scale, reference_zero_point] = - choose_qparams_per_token_asymmetric_reference_impl(input, dtype); - - // Get implementation output - auto [impl_scale, impl_zero_point] = - torch::executor::native::choose_qparams_per_token_asymmetric_aten( - input, dtype); - - // Compare 
outputs - const bool scale_correct = at::allclose(reference_scale, impl_scale); - const bool zero_point_correct = - at::equal(reference_zero_point, impl_zero_point); - - if (!scale_correct || !zero_point_correct) { - std::cout << "\n" - << "Failed with parameters: " << std::endl; - - std::cout << "input:" << std::endl; - std::cout << input << std::endl; - std::cout << "reference scale:" << std::endl; - std::cout << reference_scale << std::endl; - std::cout << "implementation scale:" << std::endl; - std::cout << impl_scale << std::endl; - std::cout << "reference zero_point:" << std::endl; - std::cout << reference_zero_point << std::endl; - std::cout << "implementation zero_point:" << std::endl; - std::cout << impl_zero_point << std::endl; - } - - ASSERT_TRUE(scale_correct && zero_point_correct); -} - -void test_vulkan_choose_qparams_per_token_asymmetric_impl( - const std::vector& input_sizes, - at::ScalarType dtype, - const vkcompute::utils::StorageType in_storage, - const vkcompute::utils::StorageType out_storage) { - std::vector input_sizes_int64( - input_sizes.begin(), input_sizes.end()); - at::Tensor input = - at::rand(input_sizes_int64, at::device(at::kCPU).dtype(at::kFloat)); - - // Calculate output sizes - std::vector output_sizes; - for (int64_t i = 0; i < input.dim() - 1; i++) { - output_sizes.push_back(input.size(i)); - } - output_sizes.push_back(1); - - // Get reference output - auto [reference_scale, reference_zero_point] = - torch::executor::native::choose_qparams_per_token_asymmetric_aten( - input, dtype); - - // Build Vulkan choose_qparams_per_token_asymmetric graph - using namespace vkcompute; - - GraphConfig config; - config.set_storage_type_override(in_storage); - ComputeGraph graph(config); - - IOValueRef r_input = graph.add_input_tensor( - input.sizes().vec(), from_at_scalartype(input.scalar_type()), in_storage); - - // Output tensors - const ValueRef r_scale = - graph.add_tensor(output_sizes, vkapi::kFloat, out_storage); - const ValueRef r_zero_point = - graph.add_tensor(output_sizes, vkapi::kInt, out_storage); - - // Create output tuple - const ValueRef r_out_tuple = graph.add_value_list({r_scale, r_zero_point}); - - // Add dtype parameter to match ATen signature - const ValueRef r_dtype = - graph.add_scalar(static_cast(dtype)); - - VK_GET_OP_FN( - "quantized_decomposed.choose_qparams_per_token_asymmetric.default") - (graph, - { - r_input.value, - r_dtype, - r_out_tuple, - }); - - ValueRef staging_scale = graph.set_output_tensor(r_scale); - ValueRef staging_zero_point = graph.set_output_tensor(r_zero_point); - - graph.prepare(); - - graph.prepack(); - - // Run Vulkan choose_qparams_per_token_asymmetric - graph.copy_into_staging( - r_input.staging, input.const_data_ptr(), input.numel()); - - graph.execute(); - - // Create output tensors to hold the results - use types that match GPU output - at::Tensor vk_scale = - at::empty(output_sizes, at::device(at::kCPU).dtype(at::kFloat)) - .contiguous(); - at::Tensor vk_zero_point = - at::empty(output_sizes, at::device(at::kCPU).dtype(at::kInt)) - .contiguous(); - - // Copy results from GPU to CPU - graph.copy_from_staging( - staging_scale, vk_scale.mutable_data_ptr(), vk_scale.numel()); - graph.copy_from_staging( - staging_zero_point, - vk_zero_point.mutable_data_ptr(), - vk_zero_point.numel()); - - // Convert reference values to match Vulkan output types for comparison - at::Tensor reference_scale_float = reference_scale.to(at::kFloat); - at::Tensor reference_zero_point_int = reference_zero_point.to(at::kInt); - - // Compare 
outputs - const bool scale_correct = at::allclose(reference_scale_float, vk_scale); - const bool zero_point_correct = - at::equal(reference_zero_point_int, vk_zero_point); - if (!scale_correct || !zero_point_correct) { - std::cout << "\n" - << "Failed with parameters: " << std::endl; - std::cout << " storage type: " - << (in_storage == vkcompute::utils::kBuffer ? "buffer" - : "texture") - << std::endl; - - if (input.numel() < 100) { - std::cout << "input:" << std::endl; - std::cout << input << "\n" << std::endl; - std::cout << "reference scale:" << std::endl; - std::cout << reference_scale << std::endl; - std::cout << "vulkan scale:" << std::endl; - std::cout << vk_scale << "\n" << std::endl; - std::cout << "reference zero_point:" << std::endl; - std::cout << reference_zero_point << std::endl; - std::cout << "vulkan zero_point:" << std::endl; - std::cout << vk_zero_point << std::endl; - } - } - - ASSERT_TRUE(scale_correct && zero_point_correct); -} - -TEST( - VulkanChooseQparamsTest, - test_reference_choose_qparams_per_token_asymmetric_int8) { - test_reference_choose_qparams_per_token_asymmetric( - {2, 3, 4}, // input sizes (2*3=6 tokens) - at::kChar); -} - -TEST( - VulkanChooseQparamsTest, - test_vulkan_choose_qparams_per_token_asymmetric_int8_1D) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - test_vulkan_choose_qparams_per_token_asymmetric({7}, at::kChar); -} - -TEST( - VulkanChooseQparamsTest, - test_vulkan_choose_qparams_per_token_asymmetric_int8_2D) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - test_vulkan_choose_qparams_per_token_asymmetric({2, 2}, at::kChar); -} - -TEST( - VulkanChooseQparamsTest, - test_vulkan_choose_qparams_per_token_asymmetric_int8_3D) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - test_vulkan_choose_qparams_per_token_asymmetric({3, 6, 4}, at::kChar); -} - -TEST( - VulkanChooseQparamsTest, - test_vulkan_choose_qparams_per_token_asymmetric_int8_4D) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - test_vulkan_choose_qparams_per_token_asymmetric({128, 2, 16, 3}, at::kChar); -} diff --git a/backends/vulkan/test/op_tests/dequantize_test.cpp b/backends/vulkan/test/op_tests/dequantize_test.cpp deleted file mode 100644 index 9fca2c632d3..00000000000 --- a/backends/vulkan/test/op_tests/dequantize_test.cpp +++ /dev/null @@ -1,2492 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. 
- */ - -#include - -#include - -#include -#include -#include - -#include -#include - -#include "test_utils.h" - -#include -#include -#include -#include - -namespace torch { -namespace executor { -namespace native { - -// Forward declarations of the functions we're testing -Tensor& dequantize_per_tensor_out( - const Tensor& input, - double scale, - int64_t zero_point, - int64_t quant_min, - int64_t quant_max, - ScalarType dtype, - executorch::aten::optional<ScalarType> out_dtype, - Tensor& out); - -Tensor& dequantize_per_token_out( - const Tensor& input, - const Tensor& scale, - const Tensor& zero_points, - int64_t quant_min, - int64_t quant_max, - ScalarType dtype, - ScalarType out_dtype, - Tensor& out); - -Tensor& dequantize_per_channel_out( - const Tensor& input, - const Tensor& scale, - const std::optional<Tensor>& zero_points, - int64_t axis, - int64_t quant_min, - int64_t quant_max, - ScalarType dtype, - executorch::aten::optional<ScalarType> out_dtype, - Tensor& out); - -Tensor& dequantize_per_tensor_tensor_args_out( - const Tensor& input, - const Tensor& scale, - const Tensor& zero_point, - int64_t quant_min, - int64_t quant_max, - ScalarType dtype, - executorch::aten::optional<ScalarType> out_dtype, - Tensor& out); - -// Wrapper function for dequantize_per_tensor_out without context -Tensor& dequantize_per_tensor_out_no_context( - const Tensor& input, - double scale, - int64_t zero_point, - int64_t quant_min, - int64_t quant_max, - ScalarType dtype, - executorch::aten::optional<ScalarType> out_dtype, - Tensor& out) { - return torch::executor::native::dequantize_per_tensor_out( - input, scale, zero_point, quant_min, quant_max, dtype, out_dtype, out); -} - -// Wrapper function for dequantize_per_token_out without context -Tensor& dequantize_per_token_out_no_context( - const Tensor& input, - const Tensor& scale, - const Tensor& zero_points, - int64_t quant_min, - int64_t quant_max, - ScalarType dtype, - ScalarType out_dtype, - Tensor& out) { - return torch::executor::native::dequantize_per_token_out( - input, scale, zero_points, quant_min, quant_max, dtype, out_dtype, out); -} - -// Wrapper function for dequantize_per_channel_out without context -Tensor& dequantize_per_channel_out_no_context( - const Tensor& input, - const Tensor& scale, - const std::optional<Tensor>& zero_points, - int64_t axis, - int64_t quant_min, - int64_t quant_max, - ScalarType dtype, - executorch::aten::optional<ScalarType> out_dtype, - Tensor& out) { - return torch::executor::native::dequantize_per_channel_out( - input, - scale, - zero_points, - axis, - quant_min, - quant_max, - dtype, - out_dtype, - out); -} - -// Wrapper function for dequantize_per_tensor_tensor_args_out without context -Tensor& dequantize_per_tensor_tensor_args_out_no_context( - const Tensor& input, - const Tensor& scale, - const Tensor& zero_point, - int64_t quant_min, - int64_t quant_max, - ScalarType dtype, - executorch::aten::optional<ScalarType> out_dtype, - Tensor& out) { - return torch::executor::native::dequantize_per_tensor_tensor_args_out( - input, scale, zero_point, quant_min, quant_max, dtype, out_dtype, out); -} - -// ATen wrapper for dequantize_per_tensor -at::Tensor dequantize_per_tensor_aten( - const at::Tensor& input, - double scale, - int64_t zero_point, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype, - at::ScalarType out_dtype) { - auto out = at::empty_like(input, out_dtype); - // Convert at::ScalarType to executorch::ScalarType - ScalarType et_dtype = at_scalartype_to_et_scalartype(dtype); - ScalarType et_out_dtype = at_scalartype_to_et_scalartype(out_dtype); - - executorch::aten::optional<ScalarType> opt_et_out_dtype(et_out_dtype);
WRAP_TO_ATEN(dequantize_per_tensor_out_no_context, 7) - (input, - scale, - zero_point, - quant_min, - quant_max, - et_dtype, - opt_et_out_dtype, - out); - return out; -} - -// ATen wrapper for dequantize_per_token -at::Tensor dequantize_per_token_aten( - const at::Tensor& input, - const at::Tensor& scale, - const at::Tensor& zero_points, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype, - at::ScalarType out_dtype) { - auto out = at::empty_like(input, out_dtype); - // Convert at::ScalarType to executorch::ScalarType - ScalarType et_dtype = at_scalartype_to_et_scalartype(dtype); - ScalarType et_out_dtype = at_scalartype_to_et_scalartype(out_dtype); - - WRAP_TO_ATEN(dequantize_per_token_out_no_context, 7) - (input, - scale, - zero_points, - quant_min, - quant_max, - et_dtype, - et_out_dtype, - out); - return out; -} - -// ATen wrapper for dequantize_per_channel -at::Tensor dequantize_per_channel_aten( - const at::Tensor& input, - const at::Tensor& scale, - const std::optional<at::Tensor>& zero_points, - int64_t axis, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype, - at::ScalarType out_dtype) { - auto out = at::empty_like(input, out_dtype); - // Convert at::ScalarType to executorch::ScalarType - ScalarType et_dtype = at_scalartype_to_et_scalartype(dtype); - ScalarType et_out_dtype = at_scalartype_to_et_scalartype(out_dtype); - - executorch::aten::optional<ScalarType> opt_et_out_dtype(et_out_dtype); - - WRAP_TO_ATEN(dequantize_per_channel_out_no_context, 8) - (input, - scale, - zero_points, - axis, - quant_min, - quant_max, - et_dtype, - opt_et_out_dtype, - out); - return out; -} - -// ATen wrapper for dequantize_per_tensor with tensor args -at::Tensor dequantize_per_tensor_tensor_args_aten( - const at::Tensor& input, - const at::Tensor& scale, - const at::Tensor& zero_point, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype, - at::ScalarType out_dtype) { - auto out = at::empty_like(input, out_dtype); - // Convert at::ScalarType to executorch::ScalarType - ScalarType et_dtype = at_scalartype_to_et_scalartype(dtype); - ScalarType et_out_dtype = at_scalartype_to_et_scalartype(out_dtype); - - executorch::aten::optional<ScalarType> opt_et_out_dtype(et_out_dtype); - - WRAP_TO_ATEN(dequantize_per_tensor_tensor_args_out_no_context, 7) - (input, - scale, - zero_point, - quant_min, - quant_max, - et_dtype, - opt_et_out_dtype, - out); - return out; -} - -} // namespace native -} // namespace executor -} // namespace torch - -void check_dequantize_args( - int64_t quant_min, - int64_t quant_max, - c10::ScalarType in_dtype, - c10::ScalarType out_dtype) { - using namespace vkcompute; - - // Check that quant_min <= quant_max - VK_CHECK_COND( - quant_min <= quant_max, - "quant_min must be <= quant_max, got quant_min: ", - quant_min, - " quant_max: ", - quant_max); - - // Check that input dtype is a quantized type - switch (in_dtype) { - case c10::kByte: - case c10::kChar: - case c10::kShort: - case c10::kInt: - case c10::kLong: - break; - default: - VK_THROW( - "Unsupported input dtype: ", - scalar_type_name(in_dtype), - " (", - static_cast<int>(in_dtype), - ")"); - } - - // Check that output dtype is a floating point type - switch (out_dtype) { - case c10::kHalf: - case c10::kFloat: - case c10::kDouble: - break; - default: - VK_THROW( - "Unsupported output dtype: ", - scalar_type_name(out_dtype), - " (", - static_cast<int>(out_dtype), - ")"); - } -}
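check_dequantize_args mirrors the preconditions the kernels rely on: an ordered [quant_min, quant_max] range, an integer input dtype, and a floating-point output dtype. A hypothetical call site, only to show which argument combinations pass the checks above (illustrative, not from the original file):

// Accepted: full signed 8-bit range, int8 input dequantized to float.
check_dequantize_args(-128, 127, c10::kChar, c10::kFloat);
// Rejected (VK_THROW): a floating-point input dtype is not a quantized type.
// check_dequantize_args(-128, 127, c10::kFloat, c10::kFloat);
// Rejected (VK_CHECK_COND): quant_min must not exceed quant_max.
// check_dequantize_args(127, -128, c10::kChar, c10::kFloat);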
/** - * Helper function to validate dequantize_per_channel arguments - * Similar to the validation in quantize_test.cpp - */ -void check_dequantize_per_channel_args( - const std::vector<int>& input_sizes, - const std::vector<float>& scales, - const std::vector<int>& zero_points, - int64_t axis) { - // Normalize axis - int64_t normalized_axis = axis; - if (normalized_axis < 0) { - normalized_axis += input_sizes.size(); - } - - ASSERT_GE(normalized_axis, 0) - << "axis " << axis << " is not legal, normalized axis " << normalized_axis - << " should be >= 0"; - - ASSERT_LT(normalized_axis, static_cast<int64_t>(input_sizes.size())) - << "axis " << axis << " is not legal, normalized axis " << normalized_axis - << " should be < input.dim() " << input_sizes.size(); - - int64_t num_channels = input_sizes[normalized_axis]; - - ASSERT_EQ(num_channels, static_cast<int64_t>(scales.size())) - << "Expected scales.size() to match input.size(axis) (" << num_channels - << "), but got " << scales.size(); - - ASSERT_EQ(num_channels, static_cast<int64_t>(zero_points.size())) - << "Expected zero_points.size() to match input.size(axis) (" - << num_channels << "), but got " << zero_points.size(); -} - -// -// Reference Implementation -// - -/* - * Reference implementation of dequantize_per_tensor - */ -at::Tensor dequantize_per_tensor_reference_impl( - const at::Tensor& input, - double scale, - int64_t zero_point, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype, - at::ScalarType out_dtype) { - // Create output tensor with the target dtype - at::Tensor out = at::empty_like(input, out_dtype); - - // Dequantize the input tensor - at::Tensor flat_input = input.flatten(); - at::Tensor flat_out = out.flatten(); - - // Store casted values to avoid repeated casting - const int32_t zero_point_int32 = static_cast<int32_t>(zero_point); - const float scale_float = static_cast<float>(scale); - - for (int i = 0; i < flat_input.numel(); i++) { - double dequantized_value = 0.0; - - // Extract quantized value and dequantize based on input dtype - // Following the CPU implementation pattern: (input - zero_point) * scale - if (dtype == at::kByte) { - uint8_t qvalue = flat_input[i].item<uint8_t>(); - dequantized_value = (qvalue - zero_point_int32) * scale_float; - } else if (dtype == at::kChar) { - int8_t qvalue = flat_input[i].item<int8_t>(); - dequantized_value = (qvalue - zero_point_int32) * scale_float; - } else if (dtype == at::kShort) { - int16_t qvalue = flat_input[i].item<int16_t>(); - dequantized_value = (qvalue - zero_point_int32) * scale_float; - } else if (dtype == at::kInt) { - int32_t qvalue = flat_input[i].item<int32_t>(); - dequantized_value = (qvalue - zero_point_int32) * scale_float; - } else if (dtype == at::kLong) { - int64_t qvalue = flat_input[i].item<int64_t>(); - dequantized_value = (qvalue - zero_point_int32) * scale_float; - } - - // Store result based on output dtype - if (out_dtype == at::kFloat) { - flat_out[i] = static_cast<float>(dequantized_value); - } else if (out_dtype == at::kDouble) { - flat_out[i] = dequantized_value; - } else if (out_dtype == at::kHalf) { - flat_out[i] = static_cast<at::Half>(dequantized_value); - } - } - - return out.reshape(input.sizes()); -} - -/* - * Reference implementation of dequantize_per_token - */ -at::Tensor dequantize_per_token_reference_impl( - const at::Tensor& input, - const at::Tensor& scale, - const at::Tensor& zero_point, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype, - at::ScalarType out_dtype) { - // Create output tensor with the target dtype - at::Tensor out = at::empty_like(input, out_dtype); - - // Calculate number of tokens - int num_tokens = 1; - for (int i = 0; i < input.dim() - 1; i++) { - num_tokens *= input.size(i); - } - - // Verify that the number of
tokens matches the size of scale and zero_point - // tensors - assert(num_tokens == scale.numel()); - assert(num_tokens == zero_point.numel()); - - // Reshape input to [num_tokens, last_dim] - at::Tensor reshaped_input = input.reshape({num_tokens, input.size(-1)}); - at::Tensor reshaped_out = out.reshape({num_tokens, input.size(-1)}); - - // Dequantize each token separately - for (int token_idx = 0; token_idx < num_tokens; token_idx++) { - // Get scale and zero_point for this token - float token_scale = scale[token_idx].item(); - int64_t token_zero_point = zero_point[token_idx].item(); - - // Store casted values to avoid repeated casting - const int32_t token_zero_point_int32 = - static_cast(token_zero_point); - - // Dequantize the token - for (int i = 0; i < input.size(-1); i++) { - double dequantized_value = 0.0; - - // Extract quantized value and dequantize based on input dtype - // Following the CPU implementation pattern: (input - zero_point) * scale - if (dtype == at::kByte) { - uint8_t qvalue = reshaped_input[token_idx][i].item(); - dequantized_value = (qvalue - token_zero_point_int32) * token_scale; - } else if (dtype == at::kChar) { - int8_t qvalue = reshaped_input[token_idx][i].item(); - dequantized_value = (qvalue - token_zero_point_int32) * token_scale; - } else if (dtype == at::kShort) { - int16_t qvalue = reshaped_input[token_idx][i].item(); - dequantized_value = (qvalue - token_zero_point_int32) * token_scale; - } else if (dtype == at::kInt) { - int32_t qvalue = reshaped_input[token_idx][i].item(); - dequantized_value = (qvalue - token_zero_point_int32) * token_scale; - } else if (dtype == at::kLong) { - int64_t qvalue = reshaped_input[token_idx][i].item(); - dequantized_value = (qvalue - token_zero_point_int32) * token_scale; - } else { - throw std::runtime_error("Unsupported input dtype"); - } - - // Store result based on output dtype - if (out_dtype == at::kFloat) { - reshaped_out[token_idx][i] = static_cast(dequantized_value); - } else if (out_dtype == at::kDouble) { - reshaped_out[token_idx][i] = dequantized_value; - } else if (out_dtype == at::kHalf) { - reshaped_out[token_idx][i] = static_cast(dequantized_value); - } - } - } - - return out; -} - -/* - * Reference implementation of dequantize_per_channel - */ -at::Tensor dequantize_per_channel_reference_impl( - const at::Tensor& input, - const at::Tensor& scale, - const std::optional& zero_point, - int64_t axis, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype, - at::ScalarType out_dtype) { - // Normalize axis to handle negative values - int64_t normalized_axis = axis; - if (normalized_axis < 0) { - normalized_axis += input.dim(); - } - - // Create output tensor with the same shape as input but with target dtype - at::Tensor output = at::empty_like(input, out_dtype); - - // Get the number of channels along the quantization axis - int64_t num_channels = input.size(normalized_axis); - - // Calculate strides for efficient indexing - std::vector input_strides; - std::vector input_sizes; - for (int64_t i = 0; i < input.dim(); i++) { - input_sizes.push_back(input.size(i)); - input_strides.push_back(input.stride(i)); - } - - // Get data pointers - const double* scale_data = scale.const_data_ptr(); - const int64_t* zero_point_data = nullptr; - if (zero_point.has_value()) { - zero_point_data = zero_point.value().const_data_ptr(); - } - - // Iterate through all elements in the tensor - int64_t total_elements = input.numel(); - - // Helper lambda to convert flat index to multi-dimensional coordinates - auto 
flat_to_coords = [&](int64_t flat_idx, std::vector& coords) { - int64_t remaining = flat_idx; - for (int64_t dim = input.dim() - 1; dim >= 0; dim--) { - coords[dim] = remaining % input_sizes[dim]; - remaining /= input_sizes[dim]; - } - }; - - // Process each element - std::vector coords(input.dim()); - for (int64_t flat_idx = 0; flat_idx < total_elements; flat_idx++) { - // Convert flat index to coordinates - flat_to_coords(flat_idx, coords); - - // Get the channel index for this element - int64_t channel_idx = coords[normalized_axis]; - - // Get the quantization parameters for this channel - double channel_scale = scale_data[channel_idx]; - int64_t channel_zero_point = 0; - if (zero_point_data != nullptr) { - channel_zero_point = zero_point_data[channel_idx]; - } - - // Store casted values to avoid repeated casting - const int32_t channel_zero_point_int32 = - static_cast(channel_zero_point); - const float channel_scale_float = static_cast(channel_scale); - - // Get the input value and dequantize - double dequantized_value = 0.0; - - // Extract quantized value and dequantize based on input dtype - // Following the CPU implementation pattern: (input - zero_point) * scale - if (dtype == at::kByte) { - uint8_t qvalue = input.flatten()[flat_idx].item(); - dequantized_value = - (qvalue - channel_zero_point_int32) * channel_scale_float; - } else if (dtype == at::kChar) { - int8_t qvalue = input.flatten()[flat_idx].item(); - dequantized_value = - (qvalue - channel_zero_point_int32) * channel_scale_float; - } else if (dtype == at::kShort) { - int16_t qvalue = input.flatten()[flat_idx].item(); - dequantized_value = - (qvalue - channel_zero_point_int32) * channel_scale_float; - } else if (dtype == at::kInt) { - int32_t qvalue = input.flatten()[flat_idx].item(); - dequantized_value = - (qvalue - channel_zero_point_int32) * channel_scale_float; - } else if (dtype == at::kLong) { - int64_t qvalue = input.flatten()[flat_idx].item(); - dequantized_value = - (qvalue - channel_zero_point_int32) * channel_scale_float; - } else { - throw std::runtime_error("Unsupported input dtype"); - } - - // Store the result based on output dtype - if (out_dtype == at::kFloat) { - output.flatten()[flat_idx] = static_cast(dequantized_value); - } else if (out_dtype == at::kDouble) { - output.flatten()[flat_idx] = dequantized_value; - } else if (out_dtype == at::kHalf) { - output.flatten()[flat_idx] = static_cast(dequantized_value); - } - } - - return output; -} - -// Forward declaration of implementation functions -void test_vulkan_dequantize_per_token_impl( - const std::vector& input_sizes, - const std::vector& scales, - const std::vector& zero_points, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype, - at::ScalarType out_dtype, - const vkcompute::utils::StorageType in_storage, - const vkcompute::utils::StorageType out_storage); - -void test_vulkan_dequantize_per_channel_impl( - const std::vector& input_sizes, - const std::vector& scales, - const std::vector& zero_points, - int64_t axis, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype, - at::ScalarType out_dtype, - const vkcompute::utils::StorageType in_storage, - const vkcompute::utils::StorageType out_storage); - -void test_vulkan_dequantize_per_tensor_tensor_impl( - const std::vector& input_sizes, - float scale, - int zero_point, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype, - at::ScalarType out_dtype, - const vkcompute::utils::StorageType in_storage, - const vkcompute::utils::StorageType out_storage); - -// Wrapper 
function to test both buffer and texture storage types -void test_vulkan_dequantize_per_token( - const std::vector& input_sizes, - const std::vector& scales, - const std::vector& zero_points, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype, - at::ScalarType out_dtype) { - // Test with buffer storage - test_vulkan_dequantize_per_token_impl( - input_sizes, - scales, - zero_points, - quant_min, - quant_max, - dtype, - out_dtype, - vkcompute::utils::kBuffer, - vkcompute::utils::kBuffer); - - // Telling the system to expect a float instead of a double - // since the shader can only return 32bit anyways - if (out_dtype == at::kDouble) { - out_dtype = at::kFloat; - } - - // Test with texture storage - test_vulkan_dequantize_per_token_impl( - input_sizes, - scales, - zero_points, - quant_min, - quant_max, - dtype, - out_dtype, - vkcompute::utils::kTexture3D, - vkcompute::utils::kTexture3D); -} - -// Wrapper function to test both buffer and texture storage types -void test_vulkan_dequantize_per_channel( - const std::vector& input_sizes, - const std::vector& scales, - const std::vector& zero_points, - int64_t axis, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype, - at::ScalarType out_dtype) { - // Test with buffer storage - test_vulkan_dequantize_per_channel_impl( - input_sizes, - scales, - zero_points, - axis, - quant_min, - quant_max, - dtype, - out_dtype, - vkcompute::utils::kBuffer, - vkcompute::utils::kBuffer); - - // Telling the system to expect a float instead of a double - // since the shader can only return 32bit anyways - if (out_dtype == at::kDouble) { - out_dtype = at::kFloat; - } - - // Test with texture storage - test_vulkan_dequantize_per_channel_impl( - input_sizes, - scales, - zero_points, - axis, - quant_min, - quant_max, - dtype, - out_dtype, - vkcompute::utils::kTexture3D, - vkcompute::utils::kTexture3D); -} - -// Wrapper function to test both buffer and texture storage types -void test_vulkan_dequantize_per_tensor_tensor( - const std::vector& input_sizes, - float scale, - int zero_point, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype, - at::ScalarType out_dtype) { - // Test with buffer storage - test_vulkan_dequantize_per_tensor_tensor_impl( - input_sizes, - scale, - zero_point, - quant_min, - quant_max, - dtype, - out_dtype, - vkcompute::utils::kBuffer, - vkcompute::utils::kBuffer); - - // Telling the system to expect a float instead of a double - // since the shader can only return 32bit anyways - if (out_dtype == at::kDouble) { - out_dtype = at::kFloat; - } - - // Test with texture storage - test_vulkan_dequantize_per_tensor_tensor_impl( - input_sizes, - scale, - zero_point, - quant_min, - quant_max, - dtype, - out_dtype, - vkcompute::utils::kTexture3D, - vkcompute::utils::kTexture3D); -} - -void test_reference_dequantize_per_tensor( - const std::vector& input_sizes, - float scale, - int zero_point, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype, - at::ScalarType out_dtype) { - check_dequantize_args(quant_min, quant_max, dtype, out_dtype); - std::vector input_sizes_int64( - input_sizes.begin(), input_sizes.end()); - - // Create a quantized input tensor with values from quant_min to quant_max - at::Tensor input; - if (dtype == at::kByte) { - input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kByte)); - } else if (dtype == at::kChar) { - input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kChar)); - } else if (dtype == at::kShort) { - input = - at::zeros(input_sizes_int64, 
at::device(at::kCPU).dtype(at::kShort)); - } else if (dtype == at::kInt) { - input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kInt)); - } else { - input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kLong)); - } - - // Fill with a simple pattern: values from quant_min to quant_max in steps - float step = 1.0f; - if (input.numel() > 1) { - step = static_cast(quant_max - quant_min) / (input.numel() - 1); - } - - auto flat_input = input.flatten(); - for (int i = 0; i < flat_input.numel(); i++) { - int64_t qvalue = quant_min + i * step; - if (dtype == at::kByte) { - flat_input[i] = static_cast(qvalue); - } else if (dtype == at::kChar) { - flat_input[i] = static_cast(qvalue); - } else if (dtype == at::kShort) { - flat_input[i] = static_cast(qvalue); - } else if (dtype == at::kInt) { - flat_input[i] = static_cast(qvalue); - } else if (dtype == at::kLong) { - flat_input[i] = static_cast(qvalue); - } - } - - // Reshape back to original dimensions - input = flat_input.reshape(input_sizes_int64); - - // Get reference output - at::Tensor reference_out = dequantize_per_tensor_reference_impl( - input, scale, zero_point, quant_min, quant_max, dtype, out_dtype); - - // Get implementation output - at::Tensor impl_out = torch::executor::native::dequantize_per_tensor_aten( - input, scale, zero_point, quant_min, quant_max, dtype, out_dtype); - - // Compare outputs - const bool output_correct = at::allclose(reference_out, impl_out); - if (!output_correct) { - std::cout << "\n" - << "Failed with parameters: " << std::endl; - std::cout << " scale: " << scale << std::endl; - std::cout << " zero_point: " << zero_point << std::endl; - std::cout << " quant_min: " << quant_min << std::endl; - std::cout << " quant_max: " << quant_max << std::endl; - std::cout << " input dtype: " << dtype << std::endl; - std::cout << " output dtype: " << out_dtype << std::endl; - - std::cout << "input:" << std::endl; - std::cout << input << std::endl; - std::cout << "reference:" << std::endl; - std::cout << reference_out << std::endl; - std::cout << "implementation:" << std::endl; - std::cout << impl_out << std::endl; - } - - ASSERT_TRUE(output_correct); -} - -TEST( - VulkanDequantizePerTensorTest, - test_reference_dequantize_per_tensor_uint8_to_float) { - test_reference_dequantize_per_tensor( - {2, 3, 4}, // input sizes - 0.1, // scale - 5, // zero_point - 0, // quant_min - 255, // quant_max - at::kByte, // input dtype - at::kFloat); // output dtype -} - -TEST( - VulkanDequantizePerTensorTest, - test_reference_dequantize_per_tensor_int8_to_float) { - test_reference_dequantize_per_tensor( - {3, 4, 5}, // input sizes - 0.05, // scale - 0, // zero_point - -128, // quant_min - 127, // quant_max - at::kChar, // input dtype - at::kFloat); // output dtype -} - -TEST( - VulkanDequantizePerTensorTest, - test_reference_dequantize_per_tensor_int32_to_float) { - test_reference_dequantize_per_tensor( - {4, 6, 2}, // input sizes - 0.2, // scale - 2, // zero_point - std::numeric_limits::min(), // quant_min - std::numeric_limits::max(), // quant_max - at::kInt, // input dtype - at::kFloat); // output dtype -} - -TEST( - VulkanDequantizePerTensorTest, - test_reference_dequantize_per_tensor_uint8_to_half) { - test_reference_dequantize_per_tensor( - {7, 4}, // input sizes - 0.1, // scale - 10, // zero_point - 0, // quant_min - 255, // quant_max - at::kByte, // input dtype (uint8) - at::kHalf); // output dtype -} - -TEST( - VulkanDequantizePerTensorTest, - test_reference_dequantize_per_tensor_int32_to_half) { - 
test_reference_dequantize_per_tensor( - {2, 6, 5}, // input sizes - 0.3, // scale - -10, // zero_point - std::numeric_limits::min(), // quant_min - std::numeric_limits::max(), // quant_max - at::kInt, // input dtype - at::kHalf); // output dtype -} - -// No Vulkan tests for quantized_decomposed.dequantize_per_tensor.default -// because it is not going to be implemented in Vulkan since we will -// be handling any future calls to this op via the export stage - -void test_reference_dequantize_per_token( - const std::vector& input_sizes, - const std::vector& scales, - const std::vector& zero_points, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype, - at::ScalarType out_dtype) { - check_dequantize_args(quant_min, quant_max, dtype, out_dtype); - int num_tokens = 1; - for (int i = 0; i < input_sizes.size() - 1; i++) { - num_tokens *= input_sizes[i]; - } - - ASSERT_EQ(num_tokens, scales.size()); - ASSERT_EQ(num_tokens, zero_points.size()); - - // Create input tensor with quantized values - std::vector input_sizes_int64( - input_sizes.begin(), input_sizes.end()); - at::Tensor input; - if (dtype == at::kByte) { - input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kByte)); - } else if (dtype == at::kChar) { - input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kChar)); - } else if (dtype == at::kShort) { - input = - at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kShort)); - } else if (dtype == at::kInt) { - input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kInt)); - } else { - input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kLong)); - } - - // Fill with a simple pattern: values from quant_min to quant_max in steps - at::Tensor reshaped_input = input.reshape({num_tokens, input.size(-1)}); - for (int token_idx = 0; token_idx < num_tokens; token_idx++) { - float step = 1.0f; - if (input.size(-1) > 1) { - step = static_cast(quant_max - quant_min) / (input.size(-1) - 1); - } - - for (int i = 0; i < input.size(-1); i++) { - int64_t qvalue = quant_min + i * step; - if (dtype == at::kByte) { - reshaped_input[token_idx][i] = static_cast(qvalue); - } else if (dtype == at::kChar) { - reshaped_input[token_idx][i] = static_cast(qvalue); - } else if (dtype == at::kShort) { - reshaped_input[token_idx][i] = static_cast(qvalue); - } else if (dtype == at::kInt) { - reshaped_input[token_idx][i] = static_cast(qvalue); - } else if (dtype == at::kLong) { - reshaped_input[token_idx][i] = static_cast(qvalue); - } - } - } - - // Reshape back to original dimensions - input = reshaped_input.reshape(input_sizes_int64); - - // Create scale and zero_point tensors - at::Tensor scale_tensor = - at::tensor(scales, at::device(at::kCPU).dtype(at::kFloat)); - at::Tensor zero_point_tensor = - at::tensor(zero_points, at::device(at::kCPU).dtype(at::kLong)); - - // Get reference output - at::Tensor reference_out = dequantize_per_token_reference_impl( - input, - scale_tensor, - zero_point_tensor, - quant_min, - quant_max, - dtype, - out_dtype); - - // Get implementation output - at::Tensor impl_out = torch::executor::native::dequantize_per_token_aten( - input, - scale_tensor, - zero_point_tensor, - quant_min, - quant_max, - dtype, - out_dtype); - - // Compare outputs - const bool output_correct = at::allclose(reference_out, impl_out); - if (!output_correct) { - std::cout << "\n" - << "Failed with parameters: " << std::endl; - std::cout << " scale(s):"; - for (size_t i = 0; i < scales.size(); i++) { - std::cout << " " << scales[i] << " "; - } 
- std::cout << "" << std::endl; - std::cout << " zero_point(s):"; - for (size_t i = 0; i < zero_points.size(); i++) { - std::cout << " " << zero_points[i] << " "; - } - std::cout << "" << std::endl; - std::cout << " quant_min: " << quant_min << std::endl; - std::cout << " quant_max: " << quant_max << std::endl; - std::cout << " input dtype: " << dtype << std::endl; - std::cout << " output dtype: " << out_dtype << std::endl; - - std::cout << "input:" << std::endl; - std::cout << input << std::endl; - std::cout << "reference:" << std::endl; - std::cout << reference_out << std::endl; - std::cout << "implementation:" << std::endl; - std::cout << impl_out << std::endl; - } - - ASSERT_TRUE(output_correct); -} - -void test_vulkan_dequantize_per_token_impl( - const std::vector& input_sizes, - const std::vector& scales, - const std::vector& zero_points, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype, - at::ScalarType out_dtype, - const vkcompute::utils::StorageType in_storage, - const vkcompute::utils::StorageType out_storage) { - check_dequantize_args(quant_min, quant_max, dtype, out_dtype); - int num_tokens = 1; - for (int i = 0; i < input_sizes.size() - 1; i++) { - num_tokens *= input_sizes[i]; - } - - ASSERT_EQ(num_tokens, scales.size()); - ASSERT_EQ(num_tokens, zero_points.size()); - - // Create input tensor with quantized values - std::vector input_sizes_int64( - input_sizes.begin(), input_sizes.end()); - at::Tensor input; - if (dtype == at::kByte) { - input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kByte)); - } else if (dtype == at::kChar) { - input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kChar)); - } else if (dtype == at::kShort) { - input = - at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kShort)); - } else if (dtype == at::kInt) { - input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kInt)); - } else { - input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kLong)); - } - - // Fill with a simple pattern: values from quant_min to quant_max in steps - at::Tensor reshaped_input = input.reshape({num_tokens, input.size(-1)}); - for (int token_idx = 0; token_idx < num_tokens; token_idx++) { - float step = 1.0f; - if (input.size(-1) > 1) { - step = static_cast(quant_max - quant_min) / (input.size(-1) - 1); - } - - for (int i = 0; i < input.size(-1); i++) { - int64_t qvalue = quant_min + i * step; - if (dtype == at::kByte) { - reshaped_input[token_idx][i] = static_cast(qvalue); - } else if (dtype == at::kChar) { - reshaped_input[token_idx][i] = static_cast(qvalue); - } else if (dtype == at::kShort) { - reshaped_input[token_idx][i] = static_cast(qvalue); - } else if (dtype == at::kInt) { - reshaped_input[token_idx][i] = static_cast(qvalue); - } else if (dtype == at::kLong) { - reshaped_input[token_idx][i] = static_cast(qvalue); - } - } - } - - // Reshape back to original dimensions - input = reshaped_input.reshape(input_sizes_int64); - - // Create scale and zero_point tensors - at::Tensor scale_tensor = - at::tensor(scales, at::device(at::kCPU).dtype(at::kFloat)); - at::Tensor zero_point_tensor = - at::tensor(zero_points, at::device(at::kCPU).dtype(at::kLong)); - - // Get reference output - at::Tensor reference_out = torch::executor::native::dequantize_per_token_aten( - input, - scale_tensor, - zero_point_tensor, - quant_min, - quant_max, - dtype, - out_dtype); - - // Build Vulkan dequantize_per_token graph - using namespace vkcompute; - - GraphConfig config; - 
config.set_storage_type_override(in_storage); - ComputeGraph graph(config); - - IOValueRef r_input = graph.add_input_tensor( - input.sizes().vec(), from_at_scalartype(dtype), in_storage); - IOValueRef r_scale = graph.add_input_tensor( - scale_tensor.sizes().vec(), - vkapi::kFloat, - utils::kBuffer, - utils::kWidthPacked); - IOValueRef r_zero_point = graph.add_input_tensor( - zero_point_tensor.sizes().vec(), - vkapi::kInt, - utils::kBuffer, - utils::kWidthPacked); - - const ValueRef r_quant_min = graph.add_scalar(quant_min); - const ValueRef r_quant_max = graph.add_scalar(quant_max); - - const ValueRef r_out = graph.add_tensor( - input.sizes().vec(), from_at_scalartype(out_dtype), out_storage); - - const ValueRef r_dtype = - graph.add_scalar(static_cast(out_dtype)); - - VK_GET_OP_FN("quantized_decomposed.dequantize_per_token.default") - (graph, - { - r_input.value, - r_scale.value, - r_zero_point.value, - r_quant_min, - r_quant_max, - r_dtype, - r_dtype, - r_out, - }); - - ValueRef staging_out = graph.set_output_tensor(r_out); - - graph.prepare(); - - graph.prepack(); - - // Copy input data to GPU - graph.copy_into_staging( - r_input.staging, input.const_data_ptr(), input.numel()); - - // Convert scale tensor to float and copy to GPU - at::Tensor scale_float = scale_tensor.to(at::kFloat); - graph.copy_into_staging( - r_scale.staging, scale_float.const_data_ptr(), scale_float.numel()); - - // Convert zero_point tensor to int and copy to GPU - at::Tensor zero_point_int = zero_point_tensor.to(at::kInt); - graph.copy_into_staging( - r_zero_point.staging, - zero_point_int.const_data_ptr(), - zero_point_int.numel()); - - // Execute the graph - graph.execute(); - - // Copy output data back to CPU - at::Tensor vk_out = at::empty_like(reference_out).contiguous(); - graph.copy_from_staging( - staging_out, vk_out.mutable_data_ptr(), vk_out.numel()); - - // Compare outputs with appropriate tolerance for half precision - bool output_correct; - if (out_dtype == at::kHalf) { - // Use higher tolerance for half precision due to limited precision - output_correct = - at::allclose(reference_out, vk_out, /*rtol=*/1e-2, /*atol=*/1e-2); - } else { - output_correct = - at::allclose(reference_out, vk_out, /*rtol=*/1e-5, /*atol=*/1e-5); - } - if (!output_correct) { - std::cout << "\n" - << "Failed with parameters: " << std::endl; - std::cout << " scale(s):"; - for (size_t i = 0; i < scales.size(); i++) { - std::cout << " " << scales[i] << " "; - } - std::cout << "" << std::endl; - std::cout << " zero_point(s):"; - for (size_t i = 0; i < zero_points.size(); i++) { - std::cout << " " << zero_points[i] << " "; - } - std::cout << "" << std::endl; - std::cout << " quant_min: " << quant_min << std::endl; - std::cout << " quant_max: " << quant_max << std::endl; - std::cout << " storage type: " - << (in_storage == vkcompute::utils::kBuffer ? 
"buffer" - : "texture") - << std::endl; - std::cout << " input dtype: " << dtype << std::endl; - std::cout << " output dtype: " << out_dtype << std::endl; - - std::cout << "input:" << std::endl; - std::cout << input << std::endl; - std::cout << "reference:" << std::endl; - std::cout << reference_out << std::endl; - std::cout << "vulkan:" << std::endl; - std::cout << vk_out << std::endl; - } - - ASSERT_TRUE(output_correct); -} - -TEST( - VulkanDequantizePerTokenTest, - test_reference_dequantize_per_token_uint8_to_float) { - std::vector scales = {0.1, 0.2, 0.3, 0.4, 0.5, 0.6}; - std::vector zero_points = {5, 10, 15, 20, 25, 30}; - - test_reference_dequantize_per_token( - {2, 3, 4}, // input sizes (2*3=6 tokens) - scales, - zero_points, - 0, // quant_min - 255, // quant_max - at::kByte, // input dtype - at::kFloat); // output dtype -} - -TEST( - VulkanDequantizePerTokenTest, - test_reference_dequantize_per_token_int8_to_float) { - std::vector scales = {0.05, 0.1, 0.15, 0.2}; - std::vector zero_points = {0, -5, 5, 10}; - - test_reference_dequantize_per_token( - {2, 2, 5}, // input sizes (2*2=4 tokens) - scales, - zero_points, - -128, // quant_min - 127, // quant_max - at::kChar, // input dtype - at::kFloat); // output dtype -} - -TEST( - VulkanDequantizePerTokenTest, - test_reference_dequantize_per_token_int32_to_float) { - std::vector scales = {0.05, 0.1, 0.15, 0.2}; - std::vector zero_points = {0, -5, 5, 10}; - - test_reference_dequantize_per_token( - {2, 2, 10}, // input sizes (2*2=4 tokens) - scales, - zero_points, - std::numeric_limits::min(), // quant_min - std::numeric_limits::max(), // quant_max - at::kInt, // input dtype - at::kFloat); // output dtype -} - -TEST( - VulkanDequantizePerTokenTest, - test_reference_dequantize_per_token_int8_to_half) { - std::vector scales = {0.05, 0.1, 0.15, 0.2}; - std::vector zero_points = {0, -5, 5, 10}; - - test_reference_dequantize_per_token( - {4, 1, 5}, // input sizes (4*1=4 tokens) - scales, - zero_points, - -128, // quant_min - 127, // quant_max - at::kChar, // input dtype (int8) - at::kHalf); // output dtype -} - -TEST( - VulkanDequantizePerTokenTest, - test_reference_dequantize_per_token_int32_to_half) { - std::vector scales = {0.05, 0.1}; - std::vector zero_points = {0, -5}; - - test_reference_dequantize_per_token( - {2, 2}, // input sizes (2 tokens) - scales, - zero_points, - std::numeric_limits::min(), // quant_min - std::numeric_limits::max(), // quant_max - at::kInt, // input dtype - at::kHalf); // output dtype -} - -TEST( - VulkanDequantizePerTokenTest, - test_vulkan_dequantize_per_token_uint8_to_float) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - std::vector scales = {0.1, 0.2, 0.3, 0.4, 0.5, 0.6}; - std::vector zero_points = {5, 10, 15, 20, 25, 30}; - - test_vulkan_dequantize_per_token( - {2, 3, 6}, // input sizes (2*3=6 tokens) - scales, - zero_points, - 0, // quant_min - 255, // quant_max - at::kByte, // input dtype - at::kFloat); // output dtype -} - -TEST( - VulkanDequantizePerTokenTest, - test_vulkan_dequantize_per_token_int8_to_float) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - std::vector scales = {0.05, 0.0}; - std::vector zero_points = {10, -5}; - - test_vulkan_dequantize_per_token( - {2, 2}, // input sizes (2*2=4 tokens) - scales, - zero_points, - -128, // quant_min - 127, // quant_max - at::kChar, // input dtype - at::kFloat); // output dtype -} - -TEST( - VulkanDequantizePerTokenTest, - 
TEST( - VulkanDequantizePerTokenTest, - test_vulkan_dequantize_per_token_int32_to_float) { - std::vector<float> scales = { - 0.0001, 0.0002, 0.0003, 0.0, 0.0011, 0.0102, 0.1003, 0.0}; - std::vector<int> zero_points = {100, -100, 50, -50, 12, -6, 4, -24}; - - test_vulkan_dequantize_per_token( - {2, 2, 2, 12}, // input sizes (2*2*2=8 tokens) - scales, - zero_points, - -2147483648, // quant_min - 2147483647, // quant_max - at::kInt, // input dtype - at::kFloat); // output dtype -} - -TEST( - VulkanDequantizePerTokenTest, - test_vulkan_dequantize_per_token_int8_to_half) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_float16_buffers_support()) { - GTEST_SKIP(); - } - std::vector<float> scales = {0.05, 0.2}; - std::vector<int> zero_points = {2, -5}; - - test_vulkan_dequantize_per_token( - {2, 2}, // input sizes (2 tokens) - scales, - zero_points, - -128, // quant_min - 127, // quant_max - at::kChar, // input dtype - at::kHalf); // output dtype -} - -TEST( - VulkanDequantizePerTokenTest, - test_vulkan_dequantize_per_token_int32_to_half) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_float16_buffers_support()) { - GTEST_SKIP(); - } - // Use much smaller scales to avoid overflow to infinity in half precision - // Half precision max value is ~65504, so with int32 values around 2e9, - // we need scales smaller than 65504/2e9 ≈ 3e-5 to avoid overflow - std::vector<float> scales = {1e-5, 2e-5, 1.5e-5}; - std::vector<int> zero_points = {20, -15, 1}; - - test_vulkan_dequantize_per_token( - {3, 6}, // input sizes (3 tokens) - scales, - zero_points, - std::numeric_limits<int32_t>::min(), // quant_min - std::numeric_limits<int32_t>::max(), // quant_max - at::kInt, // input dtype - at::kHalf); // output dtype -} - -TEST( - VulkanDequantizePerTokenTest, - test_vulkan_dequantize_per_token_int8_to_double) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - std::vector<float> scales = {0.05, 0.001}; - std::vector<int> zero_points = {10, -5}; - - test_vulkan_dequantize_per_token( - {2, 2}, // input sizes (2 tokens) - scales, - zero_points, - -128, // quant_min - 127, // quant_max - at::kChar, // input dtype - at::kDouble); // output dtype -} - -void test_reference_dequantize_per_channel( - const std::vector<int>& input_sizes, - const std::vector<float>& scales, - const std::vector<int>& zero_points, - int64_t axis, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype, - at::ScalarType out_dtype) { - check_dequantize_args(quant_min, quant_max, dtype, out_dtype); - check_dequantize_per_channel_args(input_sizes, scales, zero_points, axis); - - std::vector<int64_t> input_sizes_int64( - input_sizes.begin(), input_sizes.end()); - - // Create input tensor with quantized values - at::Tensor input; - if (dtype == at::kByte) { - input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kByte)); - } else if (dtype == at::kChar) { - input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kChar)); - } else if (dtype == at::kShort) { - input = - at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kShort)); - } else if (dtype == at::kInt) { - input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kInt)); - } else { - input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kLong)); - } - - // Fill with a simple pattern: values from quant_min to quant_max in steps - float step = 1.0f; - if (input.numel() > 1) { - step = static_cast<float>(quant_max - quant_min) / (input.numel() - 1); - } - -
auto flat_input = input.flatten(); - for (int i = 0; i < flat_input.numel(); i++) { - int64_t qvalue = quant_min + i * step; - if (dtype == at::kByte) { - flat_input[i] = static_cast(qvalue); - } else if (dtype == at::kChar) { - flat_input[i] = static_cast(qvalue); - } else if (dtype == at::kShort) { - flat_input[i] = static_cast(qvalue); - } else if (dtype == at::kInt) { - flat_input[i] = static_cast(qvalue); - } else if (dtype == at::kLong) { - flat_input[i] = static_cast(qvalue); - } - } - - // Reshape back to original dimensions - input = flat_input.reshape(input_sizes_int64); - - // Create scale and zero_point tensors - at::Tensor scale_tensor = - at::tensor(scales, at::device(at::kCPU).dtype(at::kDouble)); - at::Tensor zero_point_tensor = - at::tensor(zero_points, at::device(at::kCPU).dtype(at::kLong)); - - // Get reference output - at::Tensor my_ref = dequantize_per_channel_reference_impl( - input, - scale_tensor, - zero_point_tensor, - axis, - quant_min, - quant_max, - dtype, - out_dtype); - - // Get implementation output - at::Tensor cpu_ref = torch::executor::native::dequantize_per_channel_aten( - input, - scale_tensor, - zero_point_tensor, - axis, - quant_min, - quant_max, - dtype, - out_dtype); - - // Compare outputs - const bool output_correct = at::allclose(my_ref, cpu_ref); - if (!output_correct) { - std::cout << "\n" - << "Failed with parameters: " << std::endl; - std::cout << " axis: " << axis << std::endl; - std::cout << " input sizes:"; - for (size_t i = 0; i < input_sizes.size(); i++) { - std::cout << " " << input_sizes[i] << " "; - } - std::cout << "" << std::endl; - std::cout << " scale(s):"; - for (size_t i = 0; i < scales.size(); i++) { - std::cout << " " << scales[i] << " "; - } - std::cout << "" << std::endl; - std::cout << " zero_point(s):"; - for (size_t i = 0; i < zero_points.size(); i++) { - std::cout << " " << zero_points[i] << " "; - } - std::cout << "" << std::endl; - std::cout << " quant_min: " << quant_min << std::endl; - std::cout << " quant_max: " << quant_max << std::endl; - std::cout << " input dtype: " << dtype << std::endl; - std::cout << " output dtype: " << out_dtype << std::endl; - - std::cout << "input:" << std::endl; - std::cout << input << std::endl; - std::cout << "cpu_ref:" << std::endl; - std::cout << cpu_ref << std::endl; - std::cout << "my_ref:" << std::endl; - std::cout << my_ref << std::endl; - } - - ASSERT_TRUE(output_correct); -} - -void test_vulkan_dequantize_per_channel_impl( - const std::vector& input_sizes, - const std::vector& scales, - const std::vector& zero_points, - int64_t axis, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype, - at::ScalarType out_dtype, - const vkcompute::utils::StorageType in_storage, - const vkcompute::utils::StorageType out_storage) { - check_dequantize_args(quant_min, quant_max, dtype, out_dtype); - check_dequantize_per_channel_args(input_sizes, scales, zero_points, axis); - - std::vector input_sizes_int64( - input_sizes.begin(), input_sizes.end()); - - // Create random float tensor - at::Tensor float_x = - at::rand(input_sizes_int64, at::device(at::kCPU).dtype(at::kFloat)); - - // Create scale and zero_point tensors - at::Tensor scale_tensor = - at::tensor(scales, at::device(at::kCPU).dtype(at::kFloat)); - at::Tensor zero_point_tensor = - at::tensor(zero_points, at::device(at::kCPU).dtype(at::kInt)); - - // Map the dtype to the corresponding quantized type and quantize the float - // tensor - c10::ScalarType qtype; - at::Tensor adjusted_zero_points = zero_point_tensor; - - if (dtype == 
at::kByte) { - qtype = c10::kQUInt8; - // ATEN ONLY: Adjust zero points for unsigned types (must be non-negative) - adjusted_zero_points = at::clamp_min(zero_point_tensor, 0); - } else if (dtype == at::kChar) { - qtype = c10::kQInt8; - } else if (dtype == at::kInt) { - qtype = c10::kQInt32; - } else { - std::cout << "invalid dtype for ATEN: " << dtype << std::endl; - std::cout << " --> Delegating to c10::kQInt32" << std::endl; - qtype = c10::kQInt32; - } - - // Normalize axis for ATen (ATen doesn't handle negative axes in - // quantize_per_channel) - int64_t normalized_axis = axis; - if (normalized_axis < 0) { - normalized_axis += input_sizes_int64.size(); - } - - // Quantize using ATen - at::Tensor quantized_aten = at::quantize_per_channel( - float_x, scale_tensor, adjusted_zero_points, normalized_axis, qtype); - - // Get ATen dequantized output - at::Tensor aten_out = at::dequantize(quantized_aten).to(out_dtype); - - // Extract the quantized values (int_repr) to use with our implementations - at::Tensor quantized_input = quantized_aten.int_repr().to(dtype); - - // Get reference output using - // torch::executor::native::dequantize_per_channel_aten - at::Tensor reference_out = - torch::executor::native::dequantize_per_channel_aten( - quantized_input, - scale_tensor.to(at::kDouble), - zero_point_tensor.to(at::kLong), - axis, - quant_min, - quant_max, - dtype, - out_dtype); - - // Build Vulkan dequantize_per_channel graph - using namespace vkcompute; - - GraphConfig config; - config.set_storage_type_override(in_storage); - ComputeGraph graph(config); - - // Add tensors to graph - IOValueRef r_input = graph.add_input_tensor( - quantized_input.sizes().vec(), - from_at_scalartype(quantized_input.scalar_type()), - in_storage); - - IOValueRef r_scale = graph.add_input_tensor( - scale_tensor.sizes().vec(), - vkapi::kFloat, - utils::kBuffer, - utils::kWidthPacked); - - IOValueRef r_zero_point = graph.add_input_tensor( - adjusted_zero_points.sizes().vec(), - vkapi::kInt, - utils::kBuffer, - utils::kWidthPacked); - - ValueRef r_out = graph.add_tensor( - quantized_input.sizes().vec(), - from_at_scalartype(out_dtype), - out_storage); - - const ValueRef r_axis = graph.add_scalar(axis); - const ValueRef r_quant_min = graph.add_scalar(quant_min); - const ValueRef r_quant_max = graph.add_scalar(quant_max); - const ValueRef r_dtype = - graph.add_scalar(static_cast(dtype)); - const ValueRef r_output_dtype = - graph.add_scalar(static_cast(out_dtype)); - - VK_GET_OP_FN("quantized_decomposed.dequantize_per_channel.default") - (graph, - { - r_input.value, - r_scale.value, - r_zero_point.value, - r_axis, - r_quant_min, - r_quant_max, - r_dtype, - r_output_dtype, - r_out, - }); - - ValueRef staging_out = graph.set_output_tensor(r_out); - - graph.prepare(); - graph.prepack(); - - // Copy input data to GPU - graph.copy_into_staging( - r_input.staging, - quantized_input.const_data_ptr(), - quantized_input.numel()); - - // copy scale tensor to GPU - graph.copy_into_staging( - r_scale.staging, scale_tensor.const_data_ptr(), scale_tensor.numel()); - - // copy zero_point tensor to GPU - graph.copy_into_staging( - r_zero_point.staging, - zero_point_tensor.const_data_ptr(), - zero_point_tensor.numel()); - - // Execute the graph - graph.execute(); - - // Copy output data back to CPU - at::Tensor vk_out = at::empty_like(reference_out).contiguous(); - graph.copy_from_staging( - staging_out, vk_out.mutable_data_ptr(), vk_out.numel()); - - // Compare outputs with appropriate tolerance for half precision - bool output_correct; 
- if (out_dtype == at::kHalf) { - // Use higher tolerance for half precision due to limited precision - output_correct = - at::allclose(reference_out, vk_out, /*rtol=*/1e-2, /*atol=*/1e-2); - } else { - output_correct = - at::allclose(reference_out, vk_out, /*rtol=*/1e-5, /*atol=*/1e-5); - } - if (!output_correct) { - std::cout << "\n" - << "Failed with parameters: " << std::endl; - std::cout << " axis: " << axis << std::endl; - std::cout << " input sizes:"; - for (size_t i = 0; i < input_sizes.size(); i++) { - std::cout << " " << input_sizes[i] << " "; - } - std::cout << "" << std::endl; - std::cout << " scale(s):"; - for (size_t i = 0; i < scales.size(); i++) { - std::cout << " " << scales[i] << " "; - } - std::cout << "" << std::endl; - std::cout << " zero_point(s):"; - for (size_t i = 0; i < zero_points.size(); i++) { - std::cout << " " << zero_points[i] << " "; - } - std::cout << "" << std::endl; - std::cout << " quant_min: " << quant_min << std::endl; - std::cout << " quant_max: " << quant_max << std::endl; - std::cout << " input dtype: " << dtype << std::endl; - std::cout << " output dtype: " << out_dtype << std::endl; - std::cout << " storage: " << in_storage << std::endl; - std::cout << std::endl; - - std::cout << "\033[91m quantized_input: \033[0m" << std::endl; - std::cout << quantized_input << std::endl; - std::cout << "\033[91m aten: \033[0m" << std::endl; - std::cout << aten_out << std::endl; - std::cout << "\033[91m reference: \033[0m" << std::endl; - std::cout << reference_out << std::endl; - std::cout << "\033[91m vulkan: \033[0m" << std::endl; - std::cout << vk_out << std::endl; - } - - ASSERT_TRUE(output_correct); -} - -TEST( - VulkanDequantizePerChannelTest, - test_reference_dequantize_per_channel_uint8_to_float_3D_axis0) { - std::vector scales = {0.1, 0.2, 0.3}; - std::vector zero_points = {0, 5, -2}; - - test_reference_dequantize_per_channel( - {3, 4, 2}, // input sizes - scales, - zero_points, - 0, // axis - 0, // quant_min - 255, // quant_max - at::kByte, - at::kFloat); -} - -TEST( - VulkanDequantizePerChannelTest, - test_reference_dequantize_per_channel_int8_to_float_3D_axis2) { - std::vector scales = {0.1, 0.2}; - std::vector zero_points = {0, 5}; - - test_reference_dequantize_per_channel( - {3, 4, 2}, // input sizes - scales, - zero_points, - 2, // axis - -128, // quant_min - 127, // quant_max - at::kChar, - at::kFloat); -} - -TEST( - VulkanDequantizePerChannelTest, - test_reference_dequantize_per_channel_int8_to_float_3D_axisn1) { - std::vector scales = {0.1, 0.2}; - std::vector zero_points = {0, 5}; - - test_reference_dequantize_per_channel( - {3, 4, 2}, // input sizes - scales, - zero_points, - -1, // axis - -128, // quant_min - 127, // quant_max - at::kChar, - at::kFloat); -} - -TEST( - VulkanDequantizePerChannelTest, - test_reference_dequantize_per_channel_int32_to_float_4D_axis0) { - std::vector scales = {0.1, 0.2, 0.00002}; - std::vector zero_points = {0, 5, -4}; - - test_reference_dequantize_per_channel( - {3, 4, 2, 5}, // input sizes - scales, - zero_points, - 0, // axis - std::numeric_limits::min(), // quant_min - std::numeric_limits::max(), // quant_max - at::kInt, - at::kFloat); -} - -// END OF REFERENCE TESTS - -TEST( - VulkanDequantizePerChannelTest, - test_vulkan_dequantize_per_channel_int8_to_float_axis0) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - std::vector scales(9, 0.1f); - std::vector zero_points(9, 2); - - // 1D Tensor - test_vulkan_dequantize_per_channel( - {9}, // input 
sizes - scales, - zero_points, - 0, // axis - -128, // quant_min - 127, // quant_max - at::kChar, - at::kFloat); - - // 2D Tensor - test_vulkan_dequantize_per_channel( - {9, 14}, // input sizes - scales, - zero_points, - 0, // axis - -128, // quant_min - 127, // quant_max - at::kChar, - at::kFloat); - - // 3D Tensor - test_vulkan_dequantize_per_channel( - {9, 7, 11}, // input sizes - scales, - zero_points, - 0, // axis - -128, // quant_min - 127, // quant_max - at::kChar, - at::kFloat); - - // 4D Tensor - test_vulkan_dequantize_per_channel( - {9, 17, 5, 5}, // input sizes - scales, - zero_points, - 0, // axis - -128, // quant_min - 127, // quant_max - at::kChar, - at::kFloat); - - // 4D Tensor (negative axis) - test_vulkan_dequantize_per_channel( - {5, 17, 5, 9}, // input sizes - scales, - zero_points, - -1, // axis - -128, // quant_min - 127, // quant_max - at::kChar, - at::kFloat); -} - -TEST( - VulkanDequantizePerChannelTest, - test_vulkan_dequantize_per_channel_int8_to_float_axis1) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - std::vector scales(14, 0.001f); - std::vector zero_points(14, -5); - - // 2D Tensor - test_vulkan_dequantize_per_channel( - {9, 14}, // input sizes - scales, - zero_points, - 1, // axis - -128, // quant_min - 127, // quant_max - at::kChar, - at::kFloat); - - // 3D Tensor - test_vulkan_dequantize_per_channel( - {9, 14, 11}, // input sizes - scales, - zero_points, - 1, // axis - -128, // quant_min - 127, // quant_max - at::kChar, - at::kFloat); - - // 4D Tensor - test_vulkan_dequantize_per_channel( - {9, 14, 5, 5}, // input sizes - scales, - zero_points, - 1, // axis - -128, // quant_min - 127, // quant_max - at::kChar, - at::kFloat); - - // 4D Tensor (negative axis) - test_vulkan_dequantize_per_channel( - {9, 7, 14, 5}, // input sizes - scales, - zero_points, - -2, // axis - -128, // quant_min - 127, // quant_max - at::kChar, - at::kFloat); -} - -TEST( - VulkanDequantizePerChannelTest, - test_vulkan_dequantize_per_channel_int8_to_float_axis2) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - std::vector scales(11, 0.5f); - std::vector zero_points(11, 12); - - // 3D Tensor - test_vulkan_dequantize_per_channel( - {9, 14, 11}, // input sizes - scales, - zero_points, - 2, // axis - -128, // quant_min - 127, // quant_max - at::kChar, - at::kFloat); - - // 4D Tensor - test_vulkan_dequantize_per_channel( - {9, 14, 11, 5}, // input sizes - scales, - zero_points, - 2, // axis - -128, // quant_min - 127, // quant_max - at::kChar, - at::kFloat); - - // 4D Tensor (negative axis) - test_vulkan_dequantize_per_channel( - {9, 11, 14, 5}, // input sizes - scales, - zero_points, - -3, // axis - -128, // quant_min - 127, // quant_max - at::kChar, - at::kFloat); -} - -TEST( - VulkanDequantizePerChannelTest, - test_vulkan_dequantize_per_channel_int8_to_float_axis3) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - std::vector scales(7, 0.5f); - std::vector zero_points(7, 12); - - // 4D Tensor - test_vulkan_dequantize_per_channel( - {9, 14, 11, 7}, // input sizes - scales, - zero_points, - 3, // axis - -128, // quant_min - 127, // quant_max - at::kChar, - at::kFloat); - - // 4D Tensor (negative axis) - test_vulkan_dequantize_per_channel( - {7, 14, 11, 7}, // input sizes - scales, - zero_points, - -4, // axis - -128, // quant_min - 127, // quant_max - at::kChar, - at::kFloat); -} - -TEST( - 
VulkanDequantizePerChannelTest, - test_vulkan_dequantize_per_channel_uint8_to_float_comprehensive) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - std::vector scales = {0.1, 0.2, 0.0001, 0.5, 0.02}; - std::vector zero_points = {0, 5, -5, 1, 12}; - - // 4D Tensor - test_vulkan_dequantize_per_channel( - {5, 14, 11, 7}, // input sizes - scales, - zero_points, - 0, // axis - 0, // quant_min - 255, // quant_max - at::kByte, - at::kFloat); - - // 4D Tensor - test_vulkan_dequantize_per_channel( - {9, 5, 11, 7}, // input sizes - scales, - zero_points, - 1, // axis - 0, // quant_min - 255, // quant_max - at::kByte, - at::kFloat); - - // 4D Tensor - test_vulkan_dequantize_per_channel( - {9, 14, 5, 7}, // input sizes - scales, - zero_points, - 2, // axis - 0, // quant_min - 255, // quant_max - at::kByte, - at::kFloat); - - // 4D Tensor - test_vulkan_dequantize_per_channel( - {9, 14, 11, 5}, // input sizes - scales, - zero_points, - 3, // axis - 0, // quant_min - 255, // quant_max - at::kByte, - at::kFloat); - - // 4D Tensor (negative axis) - test_vulkan_dequantize_per_channel( - {5, 14, 11, 7}, // input sizes - scales, - zero_points, - -4, // axis - 0, // quant_min - 255, // quant_max - at::kByte, - at::kFloat); -} - -TEST( - VulkanDequantizePerChannelTest, - test_vulkan_dequantize_per_channel_8bit_to_half) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_float16_buffers_support()) { - GTEST_SKIP(); - } - std::vector scales = {0.1, 0.2, 0.01, 0.5, 0.02}; - std::vector zero_points = {0, 5, 5, 1, 12}; - - // 4D Tensor - test_vulkan_dequantize_per_channel( - {5, 14, 11, 7}, // input sizes - scales, - zero_points, - 0, // axis - -128, // quant_min - 127, // quant_max - at::kChar, - at::kHalf); - - // 4D Tensor - test_vulkan_dequantize_per_channel( - {9, 5, 11, 7}, // input sizes - scales, - zero_points, - 1, // axis - -128, // quant_min - 127, // quant_max - at::kChar, - at::kHalf); - - // 4D Tensor - test_vulkan_dequantize_per_channel( - {9, 14, 5, 7}, // input sizes - scales, - zero_points, - 2, // axis - 0, // quant_min - 255, // quant_max - at::kByte, - at::kHalf); - - // 4D Tensor - test_vulkan_dequantize_per_channel( - {9, 14, 11, 5}, // input sizes - scales, - zero_points, - 3, // axis - -128, // quant_min - 127, // quant_max - at::kChar, - at::kHalf); - - // 4D Tensor (negative axis) - test_vulkan_dequantize_per_channel( - {5, 14, 11, 7}, // input sizes - scales, - zero_points, - -4, // axis - 0, // quant_min - 255, // quant_max - at::kByte, - at::kHalf); -} - -TEST( - VulkanDequantizePerChannelTest, - test_vulkan_dequantize_per_channel_8bit_to_double) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - std::vector scales = {0.1, 0.2, 0.01, 0.5, 0.02}; - std::vector zero_points = {0, 5, 5, 1, 12}; - - // 4D Tensor - test_vulkan_dequantize_per_channel( - {5, 14, 11, 7}, // input sizes - scales, - zero_points, - 0, // axis - -128, // quant_min - 127, // quant_max - at::kChar, - at::kDouble); - - // 4D Tensor - test_vulkan_dequantize_per_channel( - {9, 5, 11, 7}, // input sizes - scales, - zero_points, - 1, // axis - -128, // quant_min - 127, // quant_max - at::kChar, - at::kDouble); - - // 4D Tensor - test_vulkan_dequantize_per_channel( - {9, 14, 5, 7}, // input sizes - scales, - zero_points, - 2, // axis - 0, // quant_min - 255, // quant_max - 
at::kByte, - at::kDouble); - - // 4D Tensor - test_vulkan_dequantize_per_channel( - {9, 14, 11, 5}, // input sizes - scales, - zero_points, - 3, // axis - -128, // quant_min - 127, // quant_max - at::kChar, - at::kDouble); - - // 4D Tensor (negative axis) - test_vulkan_dequantize_per_channel( - {5, 14, 11, 7}, // input sizes - scales, - zero_points, - -4, // axis - 0, // quant_min - 255, // quant_max - at::kByte, - at::kDouble); -} - -void test_vulkan_dequantize_per_tensor_tensor_impl( - const std::vector& input_sizes, - float scale, - int zero_point, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype, - at::ScalarType out_dtype, - const vkcompute::utils::StorageType in_storage, - const vkcompute::utils::StorageType out_storage) { - check_dequantize_args(quant_min, quant_max, dtype, out_dtype); - std::vector input_sizes_int64( - input_sizes.begin(), input_sizes.end()); - - // Create a quantized input tensor with values from quant_min to quant_max - at::Tensor input; - if (dtype == at::kByte) { - input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kByte)); - } else if (dtype == at::kChar) { - input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kChar)); - } else if (dtype == at::kShort) { - input = - at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kShort)); - } else if (dtype == at::kInt) { - input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kInt)); - } else { - input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kLong)); - } - - // Fill with a simple pattern: values from quant_min to quant_max in steps - float step = 1.0f; - if (input.numel() > 1) { - step = static_cast(quant_max - quant_min) / (input.numel() - 1); - } - - auto flat_input = input.flatten(); - for (int i = 0; i < flat_input.numel(); i++) { - int64_t qvalue = quant_min + i * step; - if (dtype == at::kByte) { - flat_input[i] = static_cast(qvalue); - } else if (dtype == at::kChar) { - flat_input[i] = static_cast(qvalue); - } else if (dtype == at::kShort) { - flat_input[i] = static_cast(qvalue); - } else if (dtype == at::kInt) { - flat_input[i] = static_cast(qvalue); - } else if (dtype == at::kLong) { - flat_input[i] = static_cast(qvalue); - } - } - - // Reshape back to original dimensions - input = flat_input.reshape(input_sizes_int64); - - // Create scale and zero_point as tensors (single element tensors) - at::Tensor scale_tensor = - at::tensor({scale}, at::device(at::kCPU).dtype(at::kDouble)); - at::Tensor zero_point_tensor = - at::tensor({zero_point}, at::device(at::kCPU).dtype(at::kLong)); - - // Get reference output using tensor variant - at::Tensor reference_out = - torch::executor::native::dequantize_per_tensor_tensor_args_aten( - input, - scale_tensor, - zero_point_tensor, - quant_min, - quant_max, - dtype, - out_dtype); - - // Build Vulkan dequantize_per_tensor.tensor graph - using namespace vkcompute; - - GraphConfig config; - config.set_storage_type_override(in_storage); - ComputeGraph graph(config); - - IOValueRef r_input = graph.add_input_tensor( - input.sizes().vec(), from_at_scalartype(dtype), in_storage); - - // Add scale and zero_point as tensor inputs (buffer storage, width packed) - IOValueRef r_scale = graph.add_input_tensor( - scale_tensor.sizes().vec(), - vkapi::kFloat, - utils::kBuffer, - utils::kWidthPacked); - IOValueRef r_zero_point = graph.add_input_tensor( - zero_point_tensor.sizes().vec(), - vkapi::kInt, - utils::kBuffer, - utils::kWidthPacked); - - const ValueRef r_quant_min = graph.add_scalar(quant_min); - 
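-  // (Illustrative note, not from the original file:) the tensor-args variant
-  // differs from plain per-tensor dequantize only in how scale/zero_point are
-  // delivered (1-element tensors instead of scalars); each element still maps
-  // through the usual affine formula
-  //   out = (q - zero_point) * scale
-  // e.g. q = 57, zero_point = 1, scale = 0.01  ->  (57 - 1) * 0.01 = 0.56.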
const ValueRef r_quant_max = graph.add_scalar(quant_max); - - const ValueRef r_out = graph.add_tensor( - input.sizes().vec(), from_at_scalartype(out_dtype), out_storage); - - const ValueRef r_dtype = - graph.add_scalar(static_cast(dtype)); - const ValueRef r_out_dtype = - graph.add_scalar(static_cast(out_dtype)); - - VK_GET_OP_FN("quantized_decomposed.dequantize_per_tensor.tensor") - (graph, - { - r_input.value, - r_scale.value, - r_zero_point.value, - r_quant_min, - r_quant_max, - r_dtype, - r_out_dtype, - r_out, - }); - - ValueRef staging_out = graph.set_output_tensor(r_out); - - graph.prepare(); - graph.prepack(); - - // Run Vulkan dequantize_per_tensor.tensor - graph.copy_into_staging( - r_input.staging, input.const_data_ptr(), input.numel()); - - // Convert scale tensor to float and copy to GPU - at::Tensor scale_float = scale_tensor.to(at::kFloat); - graph.copy_into_staging( - r_scale.staging, scale_float.const_data_ptr(), scale_float.numel()); - - // Convert zero_point tensor to int and copy to GPU - at::Tensor zero_point_int = zero_point_tensor.to(at::kInt); - graph.copy_into_staging( - r_zero_point.staging, - zero_point_int.const_data_ptr(), - zero_point_int.numel()); - - graph.execute(); - - at::Tensor vk_out = at::empty_like(reference_out).contiguous(); - graph.copy_from_staging( - staging_out, vk_out.mutable_data_ptr(), vk_out.numel()); - - // Compare outputs with appropriate tolerance for half precision - bool output_correct; - if (out_dtype == at::kHalf) { - // Use higher tolerance for half precision due to limited precision - output_correct = - at::allclose(reference_out, vk_out, /*rtol=*/1e-2, /*atol=*/1e-2); - } else { - output_correct = - at::allclose(reference_out, vk_out, /*rtol=*/1e-5, /*atol=*/1e-5); - } - if (!output_correct) { - std::cout << "\n" - << "Failed with parameters: " << std::endl; - std::cout << " scale: " << scale << std::endl; - std::cout << " zero_point: " << zero_point << std::endl; - std::cout << " quant_min: " << quant_min << std::endl; - std::cout << " quant_max: " << quant_max << std::endl; - std::cout << " storage type: " - << (in_storage == vkcompute::utils::kBuffer ? 
"buffer" - : "texture") - << std::endl; - std::cout << " input dtype: " << dtype << std::endl; - std::cout << " output dtype: " << out_dtype << std::endl; - - std::cout << "input:" << std::endl; - std::cout << input << std::endl; - std::cout << "reference:" << std::endl; - std::cout << reference_out << std::endl; - std::cout << "vulkan:" << std::endl; - std::cout << vk_out << std::endl; - } - - ASSERT_TRUE(output_correct); -} - -TEST( - VulkanDequantizePerTensorTensorTest, - test_vulkan_dequantize_per_tensor_tensor_int8_to_float) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - test_vulkan_dequantize_per_tensor_tensor( - {2, 3, 4}, // input sizes - 0.01, // scale - 1, // zero_point - -128, // quant_min - 127, // quant_max - at::kChar, // input dtype - at::kFloat); // output dtype -} - -TEST( - VulkanDequantizePerTensorTensorTest, - test_vulkan_dequantize_per_tensor_tensor_uint8_to_float) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - test_vulkan_dequantize_per_tensor_tensor( - {2, 3, 4, 12}, // input sizes - 0.1, // scale - 5, // zero_point - 0, // quant_min - 255, // quant_max - at::kByte, // input dtype - at::kFloat); // output dtype -} - -TEST( - VulkanDequantizePerTensorTensorTest, - test_vulkan_dequantize_per_tensor_tensor_int32_to_float) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - test_vulkan_dequantize_per_tensor_tensor( - {2, 3}, // input sizes - 0.01, // scale - 12, // zero_point - std::numeric_limits::min(), // quant_min - std::numeric_limits::max(), // quant_max - at::kInt, // input dtype - at::kFloat); // output dtype -} - -TEST( - VulkanDequantizePerTensorTensorTest, - test_vulkan_dequantize_per_tensor_tensor_uint8_to_half) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - test_vulkan_dequantize_per_tensor_tensor( - {3, 4}, // input sizes - 0.3, // scale - 2, // zero_point - 0, // quant_min - 255, // quant_max - at::kByte, // input dtype - at::kHalf); // output dtype -} - -TEST( - VulkanDequantizePerTensorTensorTest, - test_vulkan_dequantize_per_tensor_tensor_int8_to_double) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - test_vulkan_dequantize_per_tensor_tensor( - {2, 3, 4}, // input sizes - 0.03, // scale - -2, // zero_point - -128, // quant_min - 127, // quant_max - at::kChar, // input dtype - at::kDouble); // output dtype -} diff --git a/backends/vulkan/test/op_tests/quantize_test.cpp b/backends/vulkan/test/op_tests/quantize_test.cpp deleted file mode 100644 index 86eebcf9b14..00000000000 --- a/backends/vulkan/test/op_tests/quantize_test.cpp +++ /dev/null @@ -1,2188 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. 
- */
-
-#include <gtest/gtest.h>
-
-#include <ATen/ATen.h>
-
-#include <executorch/backends/vulkan/runtime/api/api.h>
-#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
-#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
-
-#include <executorch/extension/aten_util/make_aten_functor_from_et_functor.h>
-#include <executorch/kernels/quantized/NativeFunctions.h>
-
-#include "test_utils.h"
-
-#include <cassert>
-#include <iostream>
-#include <limits>
-
-float eps = 1e-7;
-
-namespace torch {
-namespace executor {
-namespace native {
-
-// Forward declarations of the functions we're testing
-Tensor& quantize_per_tensor_out(
-    const Tensor& input,
-    double scale,
-    int64_t zero_point,
-    int64_t quant_min,
-    int64_t quant_max,
-    ScalarType dtype,
-    Tensor& out);
-
-Tensor& quantize_per_token_out(
-    const Tensor& input,
-    const Tensor& scale,
-    const Tensor& zero_point,
-    int64_t quant_min,
-    int64_t quant_max,
-    ScalarType dtype,
-    Tensor& out);
-
-Tensor& quantize_per_channel_out(
-    const Tensor& input,
-    const Tensor& scale,
-    const Tensor& zero_point,
-    int64_t axis,
-    int64_t quant_min,
-    int64_t quant_max,
-    ScalarType dtype,
-    Tensor& out);
-
-Tensor& quantize_per_tensor_tensor_args_out(
-    const Tensor& input,
-    const Tensor& scale,
-    const Tensor& zero_point,
-    int64_t quant_min,
-    int64_t quant_max,
-    ScalarType dtype,
-    Tensor& out);
-
-// Wrapper function for quantize_per_tensor_out without context
-Tensor& quantize_per_tensor_out_no_context(
-    const Tensor& input,
-    double scale,
-    int64_t zero_point,
-    int64_t quant_min,
-    int64_t quant_max,
-    ScalarType dtype,
-    Tensor& out) {
-  return torch::executor::native::quantize_per_tensor_out(
-      input, scale, zero_point, quant_min, quant_max, dtype, out);
-}
-
-// Wrapper function for quantize_per_token_out without context
-Tensor& quantize_per_token_out_no_context(
-    const Tensor& input,
-    const Tensor& scale,
-    const Tensor& zero_point,
-    int64_t quant_min,
-    int64_t quant_max,
-    ScalarType dtype,
-    Tensor& out) {
-  return torch::executor::native::quantize_per_token_out(
-      input, scale, zero_point, quant_min, quant_max, dtype, out);
-}
-
-// Wrapper function for quantize_per_channel_out without context
-Tensor& quantize_per_channel_out_no_context(
-    const Tensor& input,
-    const Tensor& scale,
-    const Tensor& zero_point,
-    int64_t axis,
-    int64_t quant_min,
-    int64_t quant_max,
-    ScalarType dtype,
-    Tensor& out) {
-  return torch::executor::native::quantize_per_channel_out(
-      input, scale, zero_point, axis, quant_min, quant_max, dtype, out);
-}
-
-// Wrapper function for quantize_per_tensor_tensor_args_out without context
-Tensor& quantize_per_tensor_tensor_args_out_no_context(
-    const Tensor& input,
-    const Tensor& scale,
-    const Tensor& zero_point,
-    int64_t quant_min,
-    int64_t quant_max,
-    ScalarType dtype,
-    Tensor& out) {
-  return torch::executor::native::quantize_per_tensor_tensor_args_out(
-      input, scale, zero_point, quant_min, quant_max, dtype, out);
-}
-
-// ATen wrapper for quantize_per_tensor
-at::Tensor quantize_per_tensor_aten(
-    const at::Tensor& input,
-    double scale,
-    int64_t zero_point,
-    int64_t quant_min,
-    int64_t quant_max,
-    at::ScalarType dtype) {
-  auto out = at::empty_like(input, dtype);
-  ScalarType et_dtype = at_scalartype_to_et_scalartype(dtype);
-
-  WRAP_TO_ATEN(quantize_per_tensor_out_no_context, 6)
-  (input, scale, zero_point, quant_min, quant_max, et_dtype, out);
-  return out;
-}
-
-// ATen wrapper for quantize_per_token
-at::Tensor quantize_per_token_aten(
-    const at::Tensor& input,
-    const at::Tensor& scale,
-    const at::Tensor& zero_point,
-    int64_t quant_min,
-    int64_t quant_max,
-    at::ScalarType dtype) {
-  auto out = at::empty_like(input, dtype);
-  ScalarType et_dtype = at_scalartype_to_et_scalartype(dtype);
-
-  WRAP_TO_ATEN(quantize_per_token_out_no_context, 6)
-  (input, scale, zero_point, quant_min, 
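-  // (Reading aid; an assumption about the helper, not stated in this file:)
-  // WRAP_TO_ATEN(fn, N) comes from ExecuTorch's aten_util extension and adapts
-  // an out-variant kernel so it can be invoked with ATen tensors; N appears to
-  // be the index of the `out` argument (6 here and for the per-tensor
-  // variants, 7 for per-channel, which takes the extra `axis` parameter).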
quant_max, et_dtype, out); - return out; -} - -// ATen wrapper for quantize_per_channel -at::Tensor quantize_per_channel_aten( - const at::Tensor& input, - const at::Tensor& scale, - const at::Tensor& zero_point, - int64_t axis, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype) { - auto out = at::empty_like(input, dtype); - ScalarType et_dtype = at_scalartype_to_et_scalartype(dtype); - - WRAP_TO_ATEN(quantize_per_channel_out_no_context, 7) - (input, scale, zero_point, axis, quant_min, quant_max, et_dtype, out); - return out; -} - -// ATen wrapper for quantize_per_tensor with tensor args -at::Tensor quantize_per_tensor_tensor_args_aten( - const at::Tensor& input, - const at::Tensor& scale, - const at::Tensor& zero_point, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype) { - auto out = at::empty_like(input, dtype); - ScalarType et_dtype = at_scalartype_to_et_scalartype(dtype); - - WRAP_TO_ATEN(quantize_per_tensor_tensor_args_out_no_context, 6) - (input, scale, zero_point, quant_min, quant_max, et_dtype, out); - return out; -} - -} // namespace native -} // namespace executor -} // namespace torch - -void check_quantize_args( - int64_t quant_min, - int64_t quant_max, - c10::ScalarType out_dtype) { - using namespace vkcompute; - int32_t quant_min_lower_bound = 0, quant_max_upper_bound = 0; - switch (out_dtype) { - case c10::kByte: - quant_min_lower_bound = - static_cast(std::numeric_limits::min()); - quant_max_upper_bound = - static_cast(std::numeric_limits::max()); - break; - case c10::kChar: - quant_min_lower_bound = - static_cast(std::numeric_limits::min()); - quant_max_upper_bound = - static_cast(std::numeric_limits::max()); - break; - case c10::kBits16: - case c10::kUInt16: - quant_min_lower_bound = std::numeric_limits::min(); - quant_max_upper_bound = std::numeric_limits::max(); - break; - case c10::kShort: - quant_min_lower_bound = std::numeric_limits::min(); - quant_max_upper_bound = std::numeric_limits::max(); - break; - case c10::kInt: - quant_min_lower_bound = std::numeric_limits::min(); - quant_max_upper_bound = std::numeric_limits::max(); - break; - default: - VK_CHECK_COND(false, "Unsupported dtype: ", scalar_type_name(out_dtype)); - } - VK_CHECK_COND( - quant_min >= quant_min_lower_bound, - "quant_min out of bound for dtype, expected quant_min_lower_bound: ", - quant_min_lower_bound, - " actual quant_min: ", - quant_min); - - VK_CHECK_COND( - quant_max <= quant_max_upper_bound, - "quant_max out of bound for dtype, expected quant_max_upper_bound: ", - quant_max_upper_bound, - " actual quant_max: ", - quant_max); -} - -/** - * Helper function to validate quantize_per_channel arguments - * Similar to the validation in op_quantize.cpp - */ -void check_quantize_per_channel_args( - const std::vector& input_sizes, - const std::vector& scales, - const std::vector& zero_points, - int64_t axis) { - // Normalize axis - int64_t normalized_axis = axis; - if (normalized_axis < 0) { - normalized_axis += input_sizes.size(); - } - - ASSERT_GE(normalized_axis, 0) - << "axis " << axis << " is not legal, normalized axis " << normalized_axis - << " should be >= 0"; - - ASSERT_LT(normalized_axis, static_cast(input_sizes.size())) - << "axis " << axis << " is not legal, normalized axis " << normalized_axis - << " should be < input.dim() " << input_sizes.size(); - - int64_t num_channels = input_sizes[normalized_axis]; - - ASSERT_EQ(num_channels, static_cast(scales.size())) - << "Expected scales.size() to match input.size(axis) (" << num_channels - << "), but got " << 
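-      // (Worked example of the normalization above, for illustration:) for a
-      // 3-D input of sizes {3, 4, 2}, axis = -1 normalizes to 2, so
-      // num_channels = input_sizes[2] = 2 and exactly 2 scales and
-      // 2 zero_points are required.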
scales.size(); - - ASSERT_EQ(num_channels, static_cast(zero_points.size())) - << "Expected zero_points.size() to match input.size(axis) (" - << num_channels << "), but got " << zero_points.size(); -} - -// -// Reference Implementation -// - -/* - * Reference implementation of quantize_per_tensor - */ -at::Tensor quantize_per_tensor_reference_impl( - const at::Tensor& input, - double scale, - int64_t zero_point, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype) { - // Create output tensor with the target dtype - at::Tensor out = at::empty_like(input, dtype); - - // Quantize the input tensor - float inv_scale = 1.0 / scale; - - // Iterate through the tensor and quantize each element - at::Tensor float_input = input.to(at::kFloat); - at::Tensor float_values = float_input.flatten(); - - auto out_flat = out.flatten(); - - for (int i = 0; i < float_values.numel(); i++) { - float value = float_values[i].item(); - int64_t qvalue = zero_point + std::nearbyint(inv_scale * value); - - qvalue = std::max(qvalue, quant_min); - qvalue = std::min(qvalue, quant_max); - - if (dtype == at::kByte) { - out_flat[i] = static_cast(qvalue); - } else if (dtype == at::kChar) { - out_flat[i] = static_cast(qvalue); - } else if (dtype == at::kShort) { - out_flat[i] = static_cast(qvalue); - } else if (dtype == at::kInt) { - out_flat[i] = static_cast(qvalue); - } else if (dtype == at::kLong) { - out_flat[i] = static_cast(qvalue); - } - } - - return out.reshape(input.sizes()); -} - -/* - * Reference implementation of quantize_per_token - */ -at::Tensor quantize_per_token_reference_impl( - const at::Tensor& input, - const at::Tensor& scale, - const at::Tensor& zero_point, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype) { - // Create output tensor with the target dtype - at::Tensor out = at::empty_like(input, dtype); - - // Calculate number of tokens - int num_tokens = 1; - for (int i = 0; i < input.dim() - 1; i++) { - num_tokens *= input.size(i); - } - - // Verify that the number of tokens matches the size of scale and zero_point - // tensors - assert(num_tokens == scale.numel()); - assert(num_tokens == zero_point.numel()); - - // Reshape input to [num_tokens, last_dim] - at::Tensor reshaped_input = input.reshape({num_tokens, input.size(-1)}); - at::Tensor reshaped_out = out.reshape({num_tokens, input.size(-1)}); - - // Quantize each token separately - for (int token_idx = 0; token_idx < num_tokens; token_idx++) { - // Use float for scale since Vulkan doesn't support double - float token_scale = scale[token_idx].item(); - // Use int for zero_point since Vulkan doesn't support int64_t - int token_zero_point = zero_point[token_idx].item(); - - float inv_scale = 1.0 / token_scale; - - // Quantize the token - for (int i = 0; i < input.size(-1); i++) { - float value = reshaped_input[token_idx][i].item(); - int qvalue = token_zero_point + std::nearbyint(inv_scale * value); - - qvalue = std::max(qvalue, quant_min); - qvalue = std::min(qvalue, quant_max); - - if (dtype == at::kByte) { - reshaped_out[token_idx][i] = static_cast(qvalue); - } else if (dtype == at::kChar) { - reshaped_out[token_idx][i] = static_cast(qvalue); - } else if (dtype == at::kShort) { - reshaped_out[token_idx][i] = static_cast(qvalue); - } else if (dtype == at::kInt) { - reshaped_out[token_idx][i] = static_cast(qvalue); - } else if (dtype == at::kLong) { - reshaped_out[token_idx][i] = static_cast(qvalue); - } - } - } - - return out; -} - -/* - * Reference implementation of quantize_per_channel - */ -at::Tensor 
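-// (Illustrative summary, assuming the standard affine scheme used throughout
-// this file:) for an element x whose index along `axis` is channel c, the
-// reference computes
-//   q = clamp(zero_point[c] + nearbyint(x / scale[c]), quant_min, quant_max)
-// e.g. x = 0.6, scale[c] = 0.2, zero_point[c] = 5, int8 range [-128, 127]:
-//   q = 5 + nearbyint(3.0) = 8.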
quantize_per_channel_reference_impl( - const at::Tensor& input, - const at::Tensor& scale, - const at::Tensor& zero_point, - int64_t axis, - int64_t quant_min, - int64_t quant_max, - at::ScalarType dtype) { - // Normalize axis to handle negative values - int64_t normalized_axis = axis; - if (normalized_axis < 0) { - normalized_axis += input.dim(); - } - - // Create output tensor with the same shape as input but with target dtype - at::Tensor output = at::empty_like(input, dtype); - - // Get the number of channels along the quantization axis - int64_t num_channels = input.size(normalized_axis); - - // Calculate strides for efficient indexing - std::vector input_strides; - std::vector input_sizes; - for (int64_t i = 0; i < input.dim(); i++) { - input_sizes.push_back(input.size(i)); - input_strides.push_back(input.stride(i)); - } - - // Get data pointers - const float* input_data = input.const_data_ptr(); - const double* scale_data = scale.const_data_ptr(); - const int64_t* zero_point_data = zero_point.const_data_ptr(); - - // Iterate through all elements in the tensor - int64_t total_elements = input.numel(); - - // Helper lambda to convert flat index to multi-dimensional coordinates - auto flat_to_coords = [&](int64_t flat_idx, std::vector& coords) { - int64_t remaining = flat_idx; - for (int64_t dim = input.dim() - 1; dim >= 0; dim--) { - coords[dim] = remaining % input_sizes[dim]; - remaining /= input_sizes[dim]; - } - }; - - // Process each element - std::vector coords(input.dim()); - for (int64_t flat_idx = 0; flat_idx < total_elements; flat_idx++) { - // Convert flat index to coordinates - flat_to_coords(flat_idx, coords); - - // Get the channel index for this element - int64_t channel_idx = coords[normalized_axis]; - - // Get the quantization parameters for this channel - double channel_scale = scale_data[channel_idx]; - int64_t channel_zero_point = zero_point_data[channel_idx]; - - // Get the input value - float input_value = input_data[flat_idx]; - - // Apply quantization formula: round(input / scale) + zero_point - float inv_scale = 1.0f / static_cast(channel_scale); - int64_t quantized_value = static_cast( - static_cast(channel_zero_point) + - std::nearbyint(static_cast(inv_scale * input_value))); - - // Clamp to quantization bounds - quantized_value = std::max(quantized_value, quant_min); - quantized_value = std::min(quantized_value, quant_max); - - // Store the result based on output dtype - switch (dtype) { - case at::kByte: { - uint8_t* output_data = output.mutable_data_ptr(); - output_data[flat_idx] = static_cast(quantized_value); - break; - } - case at::kChar: { - int8_t* output_data = output.mutable_data_ptr(); - output_data[flat_idx] = static_cast(quantized_value); - break; - } - case at::kShort: { - int16_t* output_data = output.mutable_data_ptr(); - output_data[flat_idx] = static_cast(quantized_value); - break; - } - case at::kInt: { - int32_t* output_data = output.mutable_data_ptr(); - output_data[flat_idx] = static_cast(quantized_value); - break; - } - default: - assert(false && "Unsupported output dtype"); - } - } - - return output; -} - -// Forward declaration of implementation functions -void test_vulkan_quantize_per_token_impl( - const std::vector& input_sizes, - const std::vector& scales, - const std::vector& zero_points, - int64_t quant_min, - int64_t quant_max, - at::ScalarType in_dtype, - at::ScalarType dtype, - const vkcompute::utils::StorageType in_storage, - const vkcompute::utils::StorageType out_storage); - -void test_vulkan_quantize_per_channel_impl( - 
const std::vector& input_sizes, - const std::vector& scales, - const std::vector& zero_points, - int64_t axis, - int64_t quant_min, - int64_t quant_max, - at::ScalarType in_dtype, - at::ScalarType dtype, - const vkcompute::utils::StorageType in_storage, - const vkcompute::utils::StorageType out_storage); - -void test_vulkan_quantize_per_tensor_tensor_impl( - const std::vector& input_sizes, - float scale, - int zero_point, - int64_t quant_min, - int64_t quant_max, - at::ScalarType in_dtype, - at::ScalarType dtype, - const vkcompute::utils::StorageType in_storage, - const vkcompute::utils::StorageType out_storage); - -// Wrapper function to test both buffer and texture storage types -void test_vulkan_quantize_per_token( - const std::vector& input_sizes, - const std::vector& scales, - const std::vector& zero_points, - int64_t quant_min, - int64_t quant_max, - at::ScalarType in_dtype = at::kFloat, - at::ScalarType dtype = at::kInt) { - // Test with buffer storage - test_vulkan_quantize_per_token_impl( - input_sizes, - scales, - zero_points, - quant_min, - quant_max, - in_dtype, - dtype, - vkcompute::utils::kBuffer, - vkcompute::utils::kBuffer); - - // If the in_dtype is a double, convert to float for texture implementation - // since they don't support 64bit as inputs - if (in_dtype == at::kDouble) { - in_dtype = at::kFloat; - } - - // Test with texture storage - test_vulkan_quantize_per_token_impl( - input_sizes, - scales, - zero_points, - quant_min, - quant_max, - in_dtype, - dtype, - vkcompute::utils::kTexture3D, - vkcompute::utils::kTexture3D); -} - -// Wrapper function to test both buffer and texture storage types -void test_vulkan_quantize_per_channel( - const std::vector& input_sizes, - const std::vector& scales, - const std::vector& zero_points, - int64_t axis, - int64_t quant_min, - int64_t quant_max, - at::ScalarType in_dtype = at::kFloat, - at::ScalarType dtype = at::kInt) { - // Test with buffer storage - test_vulkan_quantize_per_channel_impl( - input_sizes, - scales, - zero_points, - axis, - quant_min, - quant_max, - in_dtype, - dtype, - vkcompute::utils::kBuffer, - vkcompute::utils::kBuffer); - - // If the in_dtype is a double, convert to float for texture implementation - // since they don't support 64bit as inputs - if (in_dtype == at::kDouble) { - in_dtype = at::kFloat; - } - - test_vulkan_quantize_per_channel_impl( - input_sizes, - scales, - zero_points, - axis, - quant_min, - quant_max, - in_dtype, - dtype, - vkcompute::utils::kTexture3D, - vkcompute::utils::kTexture3D); -} - -// Wrapper function to test both buffer and texture storage types -void test_vulkan_quantize_per_tensor_tensor( - const std::vector& input_sizes, - float scale, - int zero_point, - int64_t quant_min, - int64_t quant_max, - at::ScalarType in_dtype = at::kFloat, - at::ScalarType dtype = at::kInt) { - // Test with buffer storage - test_vulkan_quantize_per_tensor_tensor_impl( - input_sizes, - scale, - zero_point, - quant_min, - quant_max, - in_dtype, - dtype, - vkcompute::utils::kBuffer, - vkcompute::utils::kBuffer); - - // If the in_dtype is a double, convert to float for texture implementation - // since they don't support 64bit as inputs - if (in_dtype == at::kDouble) { - in_dtype = at::kFloat; - } - - // Test with texture storage - test_vulkan_quantize_per_tensor_tensor_impl( - input_sizes, - scale, - zero_point, - quant_min, - quant_max, - in_dtype, - dtype, - vkcompute::utils::kTexture3D, - vkcompute::utils::kTexture3D); -} - -void test_reference_quantize_per_tensor( - const std::vector& input_sizes, 
- float scale, - int zero_point, - int64_t quant_min, - int64_t quant_max, - at::ScalarType in_dtype = at::kFloat, - at::ScalarType dtype = at::kInt) { - check_quantize_args(quant_min, quant_max, dtype); - std::vector input_sizes_int64( - input_sizes.begin(), input_sizes.end()); - at::Tensor input = - at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(in_dtype)); - - // Fill with a simple pattern: values from 0 to 1 in steps - float step = 1.0f / (input.numel() - 1); - auto flat_input = input.flatten(); - for (int i = 0; i < flat_input.numel(); i++) { - flat_input[i] = i * step; - } - - // Reshape back to original dimensions - input = flat_input.reshape(input_sizes_int64); - - scale = scale < eps ? eps : scale; - - // Get reference output - at::Tensor reference_out = quantize_per_tensor_reference_impl( - input, scale, zero_point, quant_min, quant_max, dtype); - - // Get implementation output - at::Tensor impl_out = torch::executor::native::quantize_per_tensor_aten( - input, scale, zero_point, quant_min, quant_max, dtype); - - // Convert to int for consistent display regardless of underlying type - at::Tensor reference_int = reference_out.to(at::kInt); - at::Tensor impl_int = impl_out.to(at::kInt); - - const bool output_correct = at::equal(reference_int, impl_int); - if (!output_correct) { - at::Tensor diffs = at::abs(reference_int - impl_int); - - std::cout << "\n" - << "Failed with parameters: " << std::endl; - std::cout << " scale: " << scale << std::endl; - std::cout << " zero_point: " << zero_point << std::endl; - std::cout << " quant_min: " << quant_min << std::endl; - std::cout << " quant_max: " << quant_max << std::endl; - - std::cout << "input:" << std::endl; - std::cout << input << std::endl; - std::cout << "reference:" << std::endl; - std::cout << reference_int << std::endl; - std::cout << "my_reference:" << std::endl; - std::cout << impl_int << std::endl; - } - - ASSERT_TRUE(output_correct); -} - -TEST( - VulkanQuantizePerTensorTest, - test_reference_quantize_per_tensor_float_to_int8) { - test_reference_quantize_per_tensor( - {2, 3, 4}, // input sizes - 0.1, // scale - 0, // zero_point - -128, // quant_min - 127, // quant_max - at::kFloat, - at::kChar); -} - -TEST( - VulkanQuantizePerTensorTest, - test_reference_quantize_per_tensor_float_to_int32) { - test_reference_quantize_per_tensor( - {2, 3, 4}, // input sizes - 0.04, // scale - 5, // zero_point - std::numeric_limits::min(), // quant_min - std::numeric_limits::max(), // quant_max - at::kFloat, - at::kInt); -} - -TEST( - VulkanQuantizePerTensorTest, - test_reference_quantize_per_tensor_half_to_uint8) { - test_reference_quantize_per_tensor( - {2, 3, 4}, // input sizes - 0.2, // scale - 2, // zero_point - 0, // quant_min - 255, // quant_max - at::kHalf, - at::kByte); -} - -TEST( - VulkanQuantizePerTensorTest, - test_reference_quantize_per_tensor_half_to_int32) { - test_reference_quantize_per_tensor( - {2, 3, 4}, // input sizes - 0.01, // scale - 1, // zero_point - std::numeric_limits::min(), // quant_min - std::numeric_limits::max(), // quant_max - at::kHalf, - at::kInt); -} - -// No Vulkan tests for quantized_decomposed.quantize_per_tensor.default -// because it is not going to be implemented in Vulkan since we will -// be handling any future calls to this op via the export stage - -void test_reference_quantize_per_token( - const std::vector& input_sizes, - const std::vector& pre_scales, - const std::vector& zero_points, - int64_t quant_min, - int64_t quant_max, - at::ScalarType in_dtype = at::kFloat, - at::ScalarType 
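-    // (Clarifying example:) a "token" is a slice over the last dimension, so
-    // an input of sizes {2, 3, 4} has 2 * 3 = 6 tokens of 4 elements each,
-    // and scales / zero_points must both hold exactly 6 entries.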
dtype = at::kInt) { - check_quantize_args(quant_min, quant_max, dtype); - std::vector input_sizes_int64( - input_sizes.begin(), input_sizes.end()); - at::Tensor input = - at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(in_dtype)); - - // Fill with a simple pattern: values from 0 to 1 in steps - float step = 1.0 / (input.numel() - 1); - auto flat_input = input.flatten(); - for (int i = 0; i < flat_input.numel(); i++) { - flat_input[i] = i * step; - } - - // Reshape back to original dimensions - input = flat_input.reshape(input_sizes_int64); - - // Calculate number of tokens - int num_tokens = 1; - for (int i = 0; i < input.dim() - 1; i++) { - num_tokens *= input.size(i); - } - - // Verify that the number of tokens matches the size of scales and zero_points - ASSERT_EQ(num_tokens, pre_scales.size()); - ASSERT_EQ(num_tokens, zero_points.size()); - - std::vector scales = pre_scales; - for (auto& s : scales) { - s = s < eps ? eps : s; - } - - // Create scale and zero_point tensors - at::Tensor scale_tensor = - at::tensor(scales, at::device(at::kCPU).dtype(at::kDouble)); - at::Tensor zero_point_tensor = - at::tensor(zero_points, at::device(at::kCPU).dtype(at::kLong)); - - // Get reference output - at::Tensor reference_out = quantize_per_token_reference_impl( - input, scale_tensor, zero_point_tensor, quant_min, quant_max, dtype); - - // Get implementation output - at::Tensor impl_out = torch::executor::native::quantize_per_token_aten( - input, scale_tensor, zero_point_tensor, quant_min, quant_max, dtype); - - // Convert to int for consistent display regardless of underlying type - at::Tensor reference_int = reference_out.to(at::kInt); - at::Tensor impl_int = impl_out.to(at::kInt); - - const bool output_correct = at::equal(reference_int, impl_out); - if (!output_correct) { - std::cout << "\n" - << "Failed with parameters: " << std::endl; - std::cout << " scale(s):"; - for (size_t i = 0; i < scales.size(); i++) { - std::cout << " " << scales[i] << " "; - } - std::cout << "" << std::endl; - std::cout << " zero_point(s):"; - for (size_t i = 0; i < zero_points.size(); i++) { - std::cout << " " << zero_points[i] << " "; - } - std::cout << "" << std::endl; - std::cout << " quant_min: " << quant_min << std::endl; - std::cout << " quant_max: " << quant_max << std::endl; - - std::cout << "input:" << std::endl; - std::cout << input << std::endl; - std::cout << "reference:" << std::endl; - std::cout << reference_int << std::endl; - std::cout << "my_reference:" << std::endl; - std::cout << impl_out << std::endl; - } - - ASSERT_TRUE(output_correct); -} - -void test_vulkan_quantize_per_token_impl( - const std::vector& input_sizes, - const std::vector& pre_scales, - const std::vector& zero_points, - int64_t quant_min, - int64_t quant_max, - at::ScalarType in_dtype = at::kFloat, - at::ScalarType dtype = at::kInt, - const vkcompute::utils::StorageType in_storage = - vkcompute::utils::kTexture3D, - const vkcompute::utils::StorageType out_storage = - vkcompute::utils::kTexture3D) { - check_quantize_args(quant_min, quant_max, dtype); - int num_tokens = 1; - for (int i = 0; i < input_sizes.size() - 1; i++) { - num_tokens *= input_sizes[i]; - } - - ASSERT_EQ(num_tokens, pre_scales.size()); - ASSERT_EQ(num_tokens, zero_points.size()); - - std::vector scales = pre_scales; - for (auto& s : scales) { - s = s < eps ? 
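-    // (Note:) any pre_scale below eps = 1e-7 -- including the zero and
-    // negative scales some tests pass, e.g. -0.5 -- is clamped up to eps so
-    // that inv_scale = 1 / scale stays finite in the code under test.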
eps : s; - } - - // Create input tensor with random values - std::vector input_sizes_int64( - input_sizes.begin(), input_sizes.end()); - at::Tensor input = - at::rand(input_sizes_int64, at::device(at::kCPU).dtype(in_dtype)); - at::Tensor scale_tensor = - at::tensor(scales, at::device(at::kCPU).dtype(at::kDouble)); - at::Tensor zero_point_tensor = - at::tensor(zero_points, at::device(at::kCPU).dtype(at::kLong)); - - // Get reference output to show what we would compare against - at::Tensor reference_out = torch::executor::native::quantize_per_token_aten( - input, scale_tensor, zero_point_tensor, quant_min, quant_max, dtype); - - using namespace vkcompute; - - GraphConfig config; - config.set_storage_type_override(in_storage); - ComputeGraph graph(config); - - IOValueRef r_input = graph.add_input_tensor( - input.sizes().vec(), from_at_scalartype(input.scalar_type()), in_storage); - IOValueRef r_scale = graph.add_input_tensor( - scale_tensor.sizes().vec(), - vkapi::kFloat, - utils::kBuffer, - utils::kWidthPacked); - IOValueRef r_zero_point = graph.add_input_tensor( - zero_point_tensor.sizes().vec(), - vkapi::kInt, - utils::kBuffer, - utils::kWidthPacked); - - const ValueRef r_quant_min = graph.add_scalar(quant_min); - const ValueRef r_quant_max = graph.add_scalar(quant_max); - - const ValueRef r_out = graph.add_tensor( - input.sizes().vec(), from_at_scalartype(dtype), out_storage); - - const ValueRef r_dtype = - graph.add_scalar(static_cast(dtype)); - - VK_GET_OP_FN("quantized_decomposed.quantize_per_token.default") - (graph, - { - r_input.value, - r_scale.value, - r_zero_point.value, - r_quant_min, - r_quant_max, - r_dtype, - r_out, - }); - - ValueRef staging_out = graph.set_output_tensor(r_out); - - graph.prepare(); - - graph.prepack(); - - // Copy input data to GPU - graph.copy_into_staging( - r_input.staging, input.const_data_ptr(), input.numel()); - - // Convert scale tensor to float and copy to GPU - at::Tensor scale_float = scale_tensor.to(at::kFloat); - graph.copy_into_staging( - r_scale.staging, scale_float.const_data_ptr(), scale_float.numel()); - - // Convert zero_point tensor to int and copy to GPU - at::Tensor zero_point_int = zero_point_tensor.to(at::kInt); - graph.copy_into_staging( - r_zero_point.staging, - zero_point_int.const_data_ptr(), - zero_point_int.numel()); - - // Execute the graph - graph.execute(); - - // Copy output data back to CPU - at::Tensor vk_out = at::empty_like(reference_out).contiguous(); - graph.copy_from_staging( - staging_out, vk_out.mutable_data_ptr(), vk_out.numel()); - - // Compare outputs - at::Tensor reference_int = reference_out.to(at::kInt); - at::Tensor vk_int = vk_out.to(at::kInt); - - // Tolerance is 1 to address rounding errors and fp math differences between - // CPU/GPU - const bool output_correct = - at::allclose(reference_int, vk_int, /*rtol=*/1, /*atol=*/1); - if (!output_correct) { - at::Tensor diffs = at::abs(reference_int - vk_int); - - std::cout << "\n" - << "Failed with parameters: " << std::endl; - std::cout << " scale(s):"; - for (size_t i = 0; i < scales.size(); i++) { - std::cout << " " << scales[i] << " "; - } - std::cout << "" << std::endl; - std::cout << " zero_point(s):"; - for (size_t i = 0; i < zero_points.size(); i++) { - std::cout << " " << zero_points[i] << " "; - } - std::cout << "" << std::endl; - std::cout << " quant_min: " << quant_min << std::endl; - std::cout << " quant_max: " << quant_max << std::endl; - std::cout << " storage type: " - << (in_storage == vkcompute::utils::kBuffer ? 
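-    // (Why rtol/atol = 1 above, spelled out:) the compared tensors are
-    // integers, so that allclose tolerates an off-by-one per element; the CPU
-    // reference rounds via std::nearbyint (round-half-to-even by default),
-    // while the GPU may round halfway cases differently, e.g. 2.5 -> 2 on one
-    // device and 3 on the other.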
"buffer" - : "texture") - << std::endl; - - std::cout << "input:" << std::endl; - std::cout << input << std::endl; - std::cout << "reference:" << std::endl; - std::cout << reference_int << std::endl; - std::cout << "vulkan:" << std::endl; - std::cout << vk_int << std::endl; - } - - ASSERT_TRUE(output_correct); -} - -TEST( - VulkanQuantizePerTokenTest, - test_reference_quantize_per_token_float_to_int8) { - std::vector scales = {0.1, 0, 0.3, 0.1, 0.2, 0.3}; - std::vector zero_points = {1, 2, 3, 0, -1, -2}; - - test_reference_quantize_per_token( - {2, 3, 4}, // input sizes (2*3=6 tokens) - scales, - zero_points, - -128, // quant_min - 127, // quant_max - at::kFloat, - at::kChar); -} - -TEST( - VulkanQuantizePerTokenTest, - test_reference_quantize_per_token_float_to_int32) { - std::vector scales = {0.1, 0, 0.3, 0.1, 0.2, 0.3}; - std::vector zero_points = {1, 2, 3, 0, -1, -2}; - - test_reference_quantize_per_token( - {2, 3, 4}, // input sizes (2*3=6 tokens) - scales, - zero_points, - std::numeric_limits::min(), // quant_min - std::numeric_limits::max(), // quant_max - at::kFloat, - at::kInt); -} - -TEST( - VulkanQuantizePerTokenTest, - test_reference_quantize_per_token_half_to_int32) { - std::vector scales = {0.1, 0, 0.3, 0.1, 0.2, 0.3}; - std::vector zero_points = {1, 2, 3, 0, -1, -2}; - - test_reference_quantize_per_token( - {2, 3, 4}, // input sizes (2*3=6 tokens) - scales, - zero_points, - std::numeric_limits::min(), // quant_min - std::numeric_limits::max(), // quant_max - at::kHalf, - at::kInt); -} - -TEST( - VulkanQuantizePerTokenTest, - test_reference_quantize_per_token_half_to_uint8) { - std::vector scales = {0.1, 0, 0.3, 0.1, 0.2, 0.3}; - std::vector zero_points = {1, 2, 3, 0, -1, -2}; - - test_reference_quantize_per_token( - {2, 3, 4}, // input sizes (2*3=6 tokens) - scales, - zero_points, - 0, // quant_min - 255, // quant_max - at::kHalf, - at::kByte); -} - -TEST( - VulkanQuantizePerTokenTest, - test_vulkan_quantize_per_token_float_to_uint8) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - std::vector scales = { - -0.5, -0.3, -0.2, 0, 0.1, 0.8, 0.1, 0.2, 0.3, 0.4}; - std::vector zero_points = {-8, 0, 15, 20, 19, 12, 47, 1, -50, -12}; - - test_vulkan_quantize_per_token( - {5, 2, 4}, // input sizes (5*2=10 tokens) - scales, - zero_points, - 0, // quant_min - 255, // quant_max - at::kFloat, - at::kByte); -} - -TEST(VulkanQuantizePerTokenTest, test_vulkan_quantize_per_token_float_to_int8) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - std::vector scales = { - -0.5, -0.3, -0.2, 0, 0.1, 0.8, 0.1, 0.2, 0.3, 0.4}; - std::vector zero_points = {-8, 0, 15, 20, 19, 12, 47, 1, -50, -12}; - - test_vulkan_quantize_per_token( - {5, 2, 4}, // input sizes (5 tokens) - scales, - zero_points, - -128, // quant_min - 127, // quant_max - at::kFloat, - at::kChar); -} - -TEST( - VulkanQuantizePerTokenTest, - test_vulkan_quantize_per_token_float_to_int32) { - std::vector scales = { - -0.5, -0.3, -0.2, 0, 0.1, 0.8, 0.1, 0.2, 0.3, 0.4}; - std::vector zero_points = {-8, 0, 15, 20, 19, 12, 47, 1, -50, -12}; - - test_vulkan_quantize_per_token( - {5, 2, 4}, // input sizes (5*2=10 tokens) - scales, - zero_points, - -2147483648, // quant_min - 2147483647, // quant_max - at::kFloat, - at::kInt); -} - -TEST( - VulkanQuantizePerTokenTest, - test_vulkan_quantize_per_token_float_to_int32_small_scales) { - std::vector scales = { - 0, - 2.9387358770557188e-39f, - 1.40129846e-45f, - 
1.17549435e-38f, - 0.0000000000001}; - std::vector zero_points = {20, -10, 15, 200, 50}; - - test_vulkan_quantize_per_token( - {5, 2}, // input sizes (3 tokens) - scales, - zero_points, - -2147483648, // quant_min - 2147483647, // quant_max - at::kFloat, - at::kInt); -} - -TEST( - VulkanQuantizePerTokenTest, - test_vulkan_quantize_per_token_float_to_uint8_many_tokens) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - std::vector scales(18, 0.1); - std::vector zero_points(18, 5); - - // Alternate scale values - for (size_t i = 0; i < scales.size(); i++) { - scales[i] = (i % 2 == 0) ? 0.3 : -0.5; - } - - test_vulkan_quantize_per_token( - {3, 3, 2, 3}, // input sizes (3*3*2=18 tokens) - scales, - zero_points, - 0, // quant_min - 125, // quant_max - at::kFloat, - at::kByte); -} - -TEST(VulkanQuantizePerTokenTest, test_vulkan_quantize_per_token_half_to_int8) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_float16_buffers_support()) { - GTEST_SKIP(); - } - std::vector scales = {0.1, 0.2}; - std::vector zero_points = {0, 5}; - - test_vulkan_quantize_per_token( - {2, 2}, // input sizes (2*2=4 tokens) - scales, - zero_points, - -128, // quant_min - 127, // quant_max - at::kHalf, // input dtype - at::kChar); // output dtype -} - -TEST( - VulkanQuantizePerTokenTest, - test_vulkan_quantize_per_token_double_to_int8) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - std::vector scales = {0.1, 0.2}; - std::vector zero_points = {0, 5}; - - test_vulkan_quantize_per_token( - {2, 2}, // input sizes (2*2=4 tokens) - scales, - zero_points, - -128, // quant_min - 127, // quant_max - at::kDouble, // input dtype - at::kChar); // output dtype -} - -void test_reference_quantize_per_channel( - const std::vector& input_sizes, - const std::vector& pre_scales, - const std::vector& zero_points, - int64_t axis, - int64_t quant_min, - int64_t quant_max, - at::ScalarType in_dtype = at::kFloat, - at::ScalarType dtype = at::kInt) { - check_quantize_args(quant_min, quant_max, dtype); - check_quantize_per_channel_args(input_sizes, pre_scales, zero_points, axis); - - std::vector input_sizes_int64( - input_sizes.begin(), input_sizes.end()); - at::Tensor input = - at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(in_dtype)); - - // Fill with a simple pattern: values from 0 to 1 in steps - float step = 1.0f / (input.numel() - 1); - auto flat_input = input.flatten(); - for (int i = 0; i < flat_input.numel(); i++) { - flat_input[i] = i * step; - } - - // Reshape back to original dimensions - input = flat_input.reshape(input_sizes_int64); - - std::vector scales = pre_scales; - for (auto& s : scales) { - s = s < eps ? 
eps : s; - } - - // Create scale and zero_point tensors - at::Tensor scale_tensor = - at::tensor(scales, at::device(at::kCPU).dtype(at::kDouble)); - at::Tensor zero_point_tensor = - at::tensor(zero_points, at::device(at::kCPU).dtype(at::kLong)); - - // Get reference output - at::Tensor my_ref = quantize_per_channel_reference_impl( - input, - scale_tensor, - zero_point_tensor, - axis, - quant_min, - quant_max, - dtype); - - // Get implementation output - at::Tensor cpu_ref = torch::executor::native::quantize_per_channel_aten( - input, - scale_tensor, - zero_point_tensor, - axis, - quant_min, - quant_max, - dtype); - - // Get direct ATen implementation output - c10::ScalarType aten_dtype = dtype; - if (dtype == at::kChar) { - aten_dtype = c10::kQInt8; - } else if (dtype == at::kByte) { - aten_dtype = c10::kQUInt8; - } - - // Normalize axis for ATen (it doesn't handle negative values) - int64_t normalized_axis = axis; - if (normalized_axis < 0) { - normalized_axis += input.dim(); - } - - at::Tensor aten_ref = at::quantize_per_channel( - input, scale_tensor, zero_point_tensor, normalized_axis, aten_dtype); - - // Convert to int for consistent display regardless of underlying type - at::Tensor my_ref_int = my_ref.to(at::kInt); - at::Tensor cpu_ref_int = cpu_ref.to(at::kInt); - // For quantized tensors, we need to use int_repr() to get the underlying - // integer values - at::Tensor aten_ref_int = aten_ref.int_repr().to(at::kInt); - - const bool output_correct = at::equal(my_ref_int, cpu_ref_int); - if (!output_correct) { - std::cout << "\n" - << "Failed with parameters: " << std::endl; - std::cout << " axis: " << axis << std::endl; - std::cout << " input sizes:"; - for (size_t i = 0; i < input_sizes.size(); i++) { - std::cout << " " << input_sizes[i] << " "; - } - std::cout << "" << std::endl; - std::cout << " scale(s):"; - for (size_t i = 0; i < scales.size(); i++) { - std::cout << " " << scales[i] << " "; - } - std::cout << "" << std::endl; - std::cout << " zero_point(s):"; - for (size_t i = 0; i < zero_points.size(); i++) { - std::cout << " " << zero_points[i] << " "; - } - std::cout << "" << std::endl; - std::cout << " quant_min: " << quant_min << std::endl; - std::cout << " quant_max: " << quant_max << std::endl; - - std::cout << "input:" << std::endl; - std::cout << input << std::endl; - std::cout << "aten_ref:" << std::endl; - std::cout << aten_ref_int << std::endl; - std::cout << "cpu_ref:" << std::endl; - std::cout << cpu_ref_int << std::endl; - std::cout << "my_ref:" << std::endl; - std::cout << my_ref_int << std::endl; - } - - ASSERT_TRUE(output_correct); -} - -void test_vulkan_quantize_per_channel_impl( - const std::vector& input_sizes, - const std::vector& pre_scales, - const std::vector& zero_points, - int64_t axis, - int64_t quant_min, - int64_t quant_max, - at::ScalarType in_dtype = at::kFloat, - at::ScalarType dtype = at::kInt, - const vkcompute::utils::StorageType in_storage = - vkcompute::utils::kTexture3D, - const vkcompute::utils::StorageType out_storage = - vkcompute::utils::kTexture3D) { - check_quantize_args(quant_min, quant_max, dtype); - check_quantize_per_channel_args(input_sizes, pre_scales, zero_points, axis); - - std::vector scales = pre_scales; - for (auto& s : scales) { - s = s < eps ? 
eps : s; - } - - // Create input tensor with random values - std::vector input_sizes_int64( - input_sizes.begin(), input_sizes.end()); - at::Tensor input = - at::rand(input_sizes_int64, at::device(at::kCPU).dtype(in_dtype)); - at::Tensor scale_tensor = - at::tensor(scales, at::device(at::kCPU).dtype(at::kDouble)); - at::Tensor zero_point_tensor = - at::tensor(zero_points, at::device(at::kCPU).dtype(at::kLong)); - - // Get reference output - at::Tensor reference_out = torch::executor::native::quantize_per_channel_aten( - input, - scale_tensor, - zero_point_tensor, - axis, - quant_min, - quant_max, - dtype); - - using namespace vkcompute; - - GraphConfig config; - config.set_storage_type_override(in_storage); - ComputeGraph graph(config); - - IOValueRef r_input = graph.add_input_tensor( - input.sizes().vec(), from_at_scalartype(input.scalar_type()), in_storage); - IOValueRef r_scale = graph.add_input_tensor( - scale_tensor.sizes().vec(), - vkapi::kFloat, - utils::kBuffer, - utils::kWidthPacked); - IOValueRef r_zero_point = graph.add_input_tensor( - zero_point_tensor.sizes().vec(), - vkapi::kInt, - utils::kBuffer, - utils::kWidthPacked); - - const ValueRef r_axis = graph.add_scalar(axis); - const ValueRef r_quant_min = graph.add_scalar(quant_min); - const ValueRef r_quant_max = graph.add_scalar(quant_max); - - const ValueRef r_out = graph.add_tensor( - input.sizes().vec(), from_at_scalartype(dtype), out_storage); - - const ValueRef r_dtype = - graph.add_scalar(static_cast(dtype)); - - VK_GET_OP_FN("quantized_decomposed.quantize_per_channel.default") - (graph, - { - r_input.value, - r_scale.value, - r_zero_point.value, - r_axis, - r_quant_min, - r_quant_max, - r_dtype, - r_out, - }); - - ValueRef staging_out = graph.set_output_tensor(r_out); - - graph.prepare(); - graph.prepack(); - - // Copy input data to GPU - graph.copy_into_staging( - r_input.staging, input.const_data_ptr(), input.numel()); - - // Convert scale tensor to float and copy to GPU - at::Tensor scale_float = scale_tensor.to(at::kFloat); - graph.copy_into_staging( - r_scale.staging, scale_float.const_data_ptr(), scale_float.numel()); - - // Convert zero_point tensor to int and copy to GPU - at::Tensor zero_point_int = zero_point_tensor.to(at::kInt); - graph.copy_into_staging( - r_zero_point.staging, - zero_point_int.const_data_ptr(), - zero_point_int.numel()); - - // Execute the graph - graph.execute(); - - // Copy output data back to CPU - at::Tensor vk_out = at::empty_like(reference_out).contiguous(); - graph.copy_from_staging( - staging_out, vk_out.mutable_data_ptr(), vk_out.numel()); - - // Compare outputs - at::Tensor reference_int = reference_out.to(at::kInt); - at::Tensor vk_int = vk_out.to(at::kInt); - - // Tolerance is 1 to address rounding errors and fp math differences between - // CPU/GPU - const bool output_correct = - at::allclose(reference_int, vk_int, /*rtol=*/1, /*atol=*/1); - if (!output_correct) { - at::Tensor diffs = at::abs(reference_int - vk_int); - - std::cout << "\n" - << "Failed with parameters: " << std::endl; - std::cout << " axis: " << axis << std::endl; - std::cout << " input sizes:"; - for (size_t i = 0; i < input_sizes.size(); i++) { - std::cout << " " << input_sizes[i] << " "; - } - std::cout << "" << std::endl; - std::cout << " scale(s):"; - for (size_t i = 0; i < scales.size(); i++) { - std::cout << " " << scales[i] << " "; - } - std::cout << "" << std::endl; - std::cout << " zero_point(s):"; - for (size_t i = 0; i < zero_points.size(); i++) { - std::cout << " " << zero_points[i] << " "; - } - 
std::cout << "" << std::endl; - std::cout << " quant_min: " << quant_min << std::endl; - std::cout << " quant_max: " << quant_max << std::endl; - std::cout << " storage type: " - << (in_storage == vkcompute::utils::kBuffer ? "buffer" - : "texture") - << std::endl; - - std::cout << "input:" << std::endl; - std::cout << input << std::endl; - std::cout << "reference:" << std::endl; - std::cout << reference_int << std::endl; - std::cout << "vulkan:" << std::endl; - std::cout << vk_int << std::endl; - } - - ASSERT_TRUE(output_correct); -} - -TEST( - VulkanQuantizePerChannelTest, - test_reference_quantize_per_channel_float_to_int8_3D_axis0) { - std::vector scales = {0.1, 0.2, 0.3}; - std::vector zero_points = {0, 5, -2}; - - test_reference_quantize_per_channel( - {3, 4, 2}, // input sizes - scales, - zero_points, - 0, // axis - -128, // quant_min - 127, // quant_max - at::kFloat, - at::kChar); -} - -TEST( - VulkanQuantizePerChannelTest, - test_reference_quantize_per_channel_float_to_int8_3D_axis2) { - std::vector scales = {0.1, 0.2}; - std::vector zero_points = {0, 5}; - - test_reference_quantize_per_channel( - {3, 4, 2}, // input sizes - scales, - zero_points, - 2, // axis - -128, // quant_min - 127, // quant_max - at::kFloat, - at::kChar); -} - -TEST( - VulkanQuantizePerChannelTest, - test_reference_quantize_per_channel_float_to_int8_3D_axisn1) { - std::vector scales = {0.1, 0.2}; - std::vector zero_points = {0, 5}; - - test_reference_quantize_per_channel( - {3, 4, 2}, // input sizes - scales, - zero_points, - -1, // axis - -128, // quant_min - 127, // quant_max - at::kFloat, - at::kChar); -} - -TEST( - VulkanQuantizePerChannelTest, - test_reference_quantize_per_channel_float_to_int8_4D_axis0) { - std::vector scales = {0.1, 0.2, 0.00002}; - std::vector zero_points = {0, 5, -4}; - - test_reference_quantize_per_channel( - {3, 4, 2, 5}, // input sizes - scales, - zero_points, - 0, // axis - -128, // quant_min - 127, // quant_max - at::kFloat, - at::kChar); -} - -// END OF REFERENCE TESTS - -TEST( - VulkanQuantizePerChannelTest, - test_vulkan_quantize_per_channel_float_to_int8_axis0) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - std::vector scales(9, 0.1f); - std::vector zero_points(9, 2); - - // 1D Tensor - test_vulkan_quantize_per_channel( - {9}, // input sizes - scales, - zero_points, - 0, // axis - -128, // quant_min - 127, // quant_max - at::kFloat, - at::kChar); - - // 2D Tensor - test_vulkan_quantize_per_channel( - {9, 14}, // input sizes - scales, - zero_points, - 0, // axis - -128, // quant_min - 127, // quant_max - at::kFloat, - at::kChar); - - // 3D Tensor - test_vulkan_quantize_per_channel( - {9, 7, 11}, // input sizes - scales, - zero_points, - 0, // axis - -128, // quant_min - 127, // quant_max - at::kFloat, - at::kChar); - - // 4D Tensor - test_vulkan_quantize_per_channel( - {9, 17, 5, 5}, // input sizes - scales, - zero_points, - 0, // axis - -128, // quant_min - 127, // quant_max - at::kFloat, - at::kChar); - - // 4D Tensor (negative axis) - test_vulkan_quantize_per_channel( - {5, 17, 5, 9}, // input sizes - scales, - zero_points, - -1, // axis - -128, // quant_min - 127, // quant_max - at::kFloat, - at::kChar); -} - -TEST( - VulkanQuantizePerChannelTest, - test_vulkan_quantize_per_channel_float_to_int8_axis1) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - std::vector scales(14, 0.001f); - std::vector zero_points(14, -5); - - // 2D Tensor - 
-  test_vulkan_quantize_per_channel(
-      {9, 14}, // input sizes
-      scales,
-      zero_points,
-      1, // axis
-      -128, // quant_min
-      127, // quant_max
-      at::kFloat,
-      at::kChar);
-
-  // 3D Tensor
-  test_vulkan_quantize_per_channel(
-      {9, 14, 11}, // input sizes
-      scales,
-      zero_points,
-      1, // axis
-      -128, // quant_min
-      127, // quant_max
-      at::kFloat,
-      at::kChar);
-
-  // 4D Tensor
-  test_vulkan_quantize_per_channel(
-      {9, 14, 5, 5}, // input sizes
-      scales,
-      zero_points,
-      1, // axis
-      -128, // quant_min
-      127, // quant_max
-      at::kFloat,
-      at::kChar);
-
-  // 4D Tensor (negative axis)
-  test_vulkan_quantize_per_channel(
-      {9, 7, 14, 5}, // input sizes
-      scales,
-      zero_points,
-      -2, // axis
-      -128, // quant_min
-      127, // quant_max
-      at::kFloat,
-      at::kChar);
-}
-
-TEST(
-    VulkanQuantizePerChannelTest,
-    test_vulkan_quantize_per_channel_float_to_int8_axis2) {
-  if (!vkcompute::api::context()
-           ->adapter_ptr()
-           ->has_full_int8_buffers_support()) {
-    GTEST_SKIP();
-  }
-  std::vector<float> scales(11, 0.5f);
-  std::vector<int> zero_points(11, 12);
-
-  // 3D Tensor
-  test_vulkan_quantize_per_channel(
-      {9, 14, 11}, // input sizes
-      scales,
-      zero_points,
-      2, // axis
-      -128, // quant_min
-      127, // quant_max
-      at::kFloat,
-      at::kChar);
-
-  // 4D Tensor
-  test_vulkan_quantize_per_channel(
-      {9, 14, 11, 5}, // input sizes
-      scales,
-      zero_points,
-      2, // axis
-      -128, // quant_min
-      127, // quant_max
-      at::kFloat,
-      at::kChar);
-
-  // 4D Tensor (negative axis)
-  test_vulkan_quantize_per_channel(
-      {9, 11, 14, 5}, // input sizes
-      scales,
-      zero_points,
-      -3, // axis
-      -128, // quant_min
-      127, // quant_max
-      at::kFloat,
-      at::kChar);
-}
-
-TEST(
-    VulkanQuantizePerChannelTest,
-    test_vulkan_quantize_per_channel_float_to_int8_axis3) {
-  if (!vkcompute::api::context()
-           ->adapter_ptr()
-           ->has_full_int8_buffers_support()) {
-    GTEST_SKIP();
-  }
-  std::vector<float> scales(7, 0.5f);
-  std::vector<int> zero_points(7, 12);
-
-  // 4D Tensor
-  test_vulkan_quantize_per_channel(
-      {9, 14, 11, 7}, // input sizes
-      scales,
-      zero_points,
-      3, // axis
-      -128, // quant_min
-      127, // quant_max
-      at::kFloat,
-      at::kChar);
-
-  // 4D Tensor (negative axis)
-  test_vulkan_quantize_per_channel(
-      {7, 14, 11, 7}, // input sizes
-      scales,
-      zero_points,
-      -4, // axis
-      -128, // quant_min
-      127, // quant_max
-      at::kFloat,
-      at::kChar);
-}
-
-TEST(
-    VulkanQuantizePerChannelTest,
-    test_vulkan_quantize_per_channel_float_to_uint8_comprehensive) {
-  if (!vkcompute::api::context()
-           ->adapter_ptr()
-           ->has_full_int8_buffers_support()) {
-    GTEST_SKIP();
-  }
-  std::vector<float> scales = {0.1, 0.2, 0.0001, 0.5, 0.02};
-  std::vector<int> zero_points = {0, 5, -5, 1, 12};
-
-  // 4D Tensor
-  test_vulkan_quantize_per_channel(
-      {5, 14, 11, 7}, // input sizes
-      scales,
-      zero_points,
-      0, // axis
-      0, // quant_min
-      255, // quant_max
-      at::kFloat,
-      at::kByte);
-
-  // 4D Tensor
-  test_vulkan_quantize_per_channel(
-      {9, 5, 11, 7}, // input sizes
-      scales,
-      zero_points,
-      1, // axis
-      0, // quant_min
-      255, // quant_max
-      at::kFloat,
-      at::kByte);
-
-  // 4D Tensor
-  test_vulkan_quantize_per_channel(
-      {9, 14, 5, 7}, // input sizes
-      scales,
-      zero_points,
-      2, // axis
-      0, // quant_min
-      255, // quant_max
-      at::kFloat,
-      at::kByte);
-
-  // 4D Tensor
-  test_vulkan_quantize_per_channel(
-      {9, 14, 11, 5}, // input sizes
-      scales,
-      zero_points,
-      3, // axis
-      0, // quant_min
-      255, // quant_max
-      at::kFloat,
-      at::kByte);
-
-  // 4D Tensor (negative axis)
-  test_vulkan_quantize_per_channel(
-      {5, 14, 11, 7}, // input sizes
-      scales,
-      zero_points,
-      -4, // axis
-      0, // quant_min
-      255, // quant_max
-      at::kFloat,
-      at::kByte);
-}
-
-TEST(
-    VulkanQuantizePerChannelTest,
-    test_vulkan_quantize_per_channel_half_to_8bit) {
-  if (!vkcompute::api::context()
-           ->adapter_ptr()
-           ->has_full_int8_buffers_support()) {
-    GTEST_SKIP();
-  }
-  if (!vkcompute::api::context()
-           ->adapter_ptr()
-           ->has_full_float16_buffers_support()) {
-    GTEST_SKIP();
-  }
-  std::vector<float> scales = {0.1, 0.2, 0.01, 0.5, 0.02};
-  std::vector<int> zero_points = {0, 5, 5, 1, 12};
-
-  // 4D Tensor
-  test_vulkan_quantize_per_channel(
-      {5, 14, 11, 7}, // input sizes
-      scales,
-      zero_points,
-      0, // axis
-      -128, // quant_min
-      127, // quant_max
-      at::kHalf,
-      at::kChar);
-
-  // 4D Tensor
-  test_vulkan_quantize_per_channel(
-      {9, 5, 11, 7}, // input sizes
-      scales,
-      zero_points,
-      1, // axis
-      -128, // quant_min
-      127, // quant_max
-      at::kHalf,
-      at::kChar);
-
-  // 4D Tensor
-  test_vulkan_quantize_per_channel(
-      {9, 14, 5, 7}, // input sizes
-      scales,
-      zero_points,
-      2, // axis
-      0, // quant_min
-      255, // quant_max
-      at::kHalf,
-      at::kByte);
-
-  // 4D Tensor
-  test_vulkan_quantize_per_channel(
-      {9, 14, 11, 5}, // input sizes
-      scales,
-      zero_points,
-      3, // axis
-      -128, // quant_min
-      127, // quant_max
-      at::kHalf,
-      at::kChar);
-
-  // 4D Tensor (negative axis)
-  test_vulkan_quantize_per_channel(
-      {5, 14, 11, 7}, // input sizes
-      scales,
-      zero_points,
-      -4, // axis
-      0, // quant_min
-      255, // quant_max
-      at::kHalf,
-      at::kByte);
-}
-
-TEST(
-    VulkanQuantizePerChannelTest,
-    test_vulkan_quantize_per_channel_double_to_8bit) {
-  if (!vkcompute::api::context()
-           ->adapter_ptr()
-           ->has_full_int8_buffers_support()) {
-    GTEST_SKIP();
-  }
-  std::vector<float> scales = {0.1, 0.2, 0.01, 0.5, 0.02};
-  std::vector<int> zero_points = {0, 5, 5, 1, 12};
-
-  // 4D Tensor
-  test_vulkan_quantize_per_channel(
-      {5, 14, 11, 7}, // input sizes
-      scales,
-      zero_points,
-      0, // axis
-      -128, // quant_min
-      127, // quant_max
-      at::kDouble,
-      at::kChar);
-
-  // 4D Tensor
-  test_vulkan_quantize_per_channel(
-      {9, 5, 11, 7}, // input sizes
-      scales,
-      zero_points,
-      1, // axis
-      -128, // quant_min
-      127, // quant_max
-      at::kDouble,
-      at::kChar);
-
-  // 4D Tensor
-  test_vulkan_quantize_per_channel(
-      {9, 14, 5, 7}, // input sizes
-      scales,
-      zero_points,
-      2, // axis
-      0, // quant_min
-      255, // quant_max
-      at::kDouble,
-      at::kByte);
-
-  // 4D Tensor
-  test_vulkan_quantize_per_channel(
-      {9, 14, 11, 5}, // input sizes
-      scales,
-      zero_points,
-      3, // axis
-      -128, // quant_min
-      127, // quant_max
-      at::kDouble,
-      at::kChar);
-
-  // 4D Tensor (negative axis)
-  test_vulkan_quantize_per_channel(
-      {5, 14, 11, 7}, // input sizes
-      scales,
-      zero_points,
-      -4, // axis
-      0, // quant_min
-      255, // quant_max
-      at::kDouble,
-      at::kByte);
-}
-
-void test_vulkan_quantize_per_tensor_tensor_impl(
-    const std::vector<int>& input_sizes,
-    float scale,
-    int zero_point,
-    int64_t quant_min,
-    int64_t quant_max,
-    at::ScalarType in_dtype = at::kFloat,
-    at::ScalarType dtype = at::kInt,
-    const vkcompute::utils::StorageType in_storage =
-        vkcompute::utils::kTexture3D,
-    const vkcompute::utils::StorageType out_storage =
-        vkcompute::utils::kTexture3D) {
-  check_quantize_args(quant_min, quant_max, dtype);
-  std::vector<int64_t> input_sizes_int64(
-      input_sizes.begin(), input_sizes.end());
-  at::Tensor input =
-      at::rand(input_sizes_int64, at::device(at::kCPU).dtype(in_dtype));
-
-  scale = scale < eps ? eps : scale;
-
-  // Create scale and zero_point as tensors (single element tensors)
-  at::Tensor scale_tensor =
-      at::tensor({scale}, at::device(at::kCPU).dtype(at::kDouble));
-  at::Tensor zero_point_tensor =
-      at::tensor({zero_point}, at::device(at::kCPU).dtype(at::kLong));
-
-  // Get reference output using tensor variant
-  at::Tensor reference_out =
-      torch::executor::native::quantize_per_tensor_tensor_args_aten(
-          input, scale_tensor, zero_point_tensor, quant_min, quant_max, dtype);
-
-  // Build Vulkan quantize_per_tensor.tensor graph
-  using namespace vkcompute;
-
-  GraphConfig config;
-  config.set_storage_type_override(in_storage);
-  ComputeGraph graph(config);
-
-  IOValueRef r_input = graph.add_input_tensor(
-      input.sizes().vec(), from_at_scalartype(input.scalar_type()), in_storage);
-
-  // Add scale and zero_point as tensor inputs (buffer storage, width packed)
-  IOValueRef r_scale = graph.add_input_tensor(
-      scale_tensor.sizes().vec(),
-      vkapi::kFloat,
-      utils::kBuffer,
-      utils::kWidthPacked);
-  IOValueRef r_zero_point = graph.add_input_tensor(
-      zero_point_tensor.sizes().vec(),
-      vkapi::kInt,
-      utils::kBuffer,
-      utils::kWidthPacked);
-
-  const ValueRef r_quant_min = graph.add_scalar<int64_t>(quant_min);
-  const ValueRef r_quant_max = graph.add_scalar<int64_t>(quant_max);
-
-  const ValueRef r_out = graph.add_tensor(
-      input.sizes().vec(), from_at_scalartype(dtype), out_storage);
-
-  const ValueRef r_dtype =
-      graph.add_scalar<int64_t>(static_cast<int64_t>(dtype));
-
-  VK_GET_OP_FN("quantized_decomposed.quantize_per_tensor.tensor")
-  (graph,
-   {
-       r_input.value,
-       r_scale.value,
-       r_zero_point.value,
-       r_quant_min,
-       r_quant_max,
-       r_dtype,
-       r_out,
-   });
-
-  ValueRef staging_out = graph.set_output_tensor(r_out);
-
-  graph.prepare();
-  graph.prepack();
-
-  // Run Vulkan quantize_per_tensor.tensor
-  graph.copy_into_staging(
-      r_input.staging, input.const_data_ptr(), input.numel());
-
-  // Convert scale tensor to float and copy to GPU
-  at::Tensor scale_float = scale_tensor.to(at::kFloat);
-  graph.copy_into_staging(
-      r_scale.staging, scale_float.const_data_ptr(), scale_float.numel());
-
-  // Convert zero_point tensor to int and copy to GPU
-  at::Tensor zero_point_int = zero_point_tensor.to(at::kInt);
-  graph.copy_into_staging(
-      r_zero_point.staging,
-      zero_point_int.const_data_ptr(),
-      zero_point_int.numel());
-
-  graph.execute();
-
-  at::Tensor vk_out = at::empty_like(reference_out).contiguous();
-  graph.copy_from_staging(
-      staging_out, vk_out.mutable_data_ptr(), vk_out.numel());
-
-  // Compare outputs
-  // For quantized types, we need to compare the actual integer values
-  at::Tensor reference_int = reference_out.to(at::kInt);
-  at::Tensor vk_int = vk_out.to(at::kInt);
-
-  // Tolerance is 1 to address rounding errors and fp math differences between
-  // CPU/GPU
-  const bool output_correct =
-      at::allclose(reference_int, vk_int, /*rtol=*/1, /*atol=*/1);
-  if (!output_correct) {
-    at::Tensor diffs = at::abs(reference_int - vk_int);
-
-    std::cout << "\n"
-              << "Failed with parameters: " << std::endl;
-    std::cout << "  scale: " << scale << std::endl;
-    std::cout << "  zero_point: " << zero_point << std::endl;
-    std::cout << "  quant_min: " << quant_min << std::endl;
-    std::cout << "  quant_max: " << quant_max << std::endl;
-    std::cout << "  storage type: "
-              << (in_storage == vkcompute::utils::kBuffer ? "buffer"
"buffer" - : "texture") - << std::endl; - - std::cout << "input:" << std::endl; - std::cout << input << std::endl; - std::cout << "reference:" << std::endl; - std::cout << reference_int << std::endl; - std::cout << "vulkan:" << std::endl; - std::cout << vk_int << std::endl; - } - - ASSERT_TRUE(output_correct); -} - -TEST( - VulkanQuantizePerTensorTensorTest, - test_vulkan_quantize_per_tensor_tensor_float_to_int8) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - test_vulkan_quantize_per_tensor_tensor( - {2, 3, 4}, // input sizes - 0.01, // scale - 1, // zero_point - -128, // quant_min - 127, // quant_max - at::kFloat, // input dtype - at::kChar); // output dtype -} - -TEST( - VulkanQuantizePerTensorTensorTest, - test_vulkan_quantize_per_tensor_tensor_float_to_uint8) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - test_vulkan_quantize_per_tensor_tensor( - {2, 3, 4, 12}, // input sizes - 0.1, // scale - 5, // zero_point - 0, // quant_min - 255, // quant_max - at::kFloat, // input dtype - at::kByte); // output dtype -} - -TEST( - VulkanQuantizePerTensorTensorTest, - test_vulkan_quantize_per_tensor_tensor_float_to_int32) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - test_vulkan_quantize_per_tensor_tensor( - {2, 3}, // input sizes - 0.01, // scale - 12, // zero_point - std::numeric_limits::min(), // quant_min - std::numeric_limits::max(), // quant_max - at::kFloat, // input dtype - at::kInt); // output dtype -} - -TEST( - VulkanQuantizePerTensorTensorTest, - test_vulkan_quantize_per_tensor_tensor_half_to_uint8) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - test_vulkan_quantize_per_tensor_tensor( - {3, 4}, // input sizes - 0.3, // scale - 2, // zero_point - 0, // quant_min - 255, // quant_max - at::kHalf, // input dtype - at::kByte); // output dtype -} - -TEST( - VulkanQuantizePerTensorTensorTest, - test_vulkan_quantize_per_tensor_tensor_double_to_int8) { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - test_vulkan_quantize_per_tensor_tensor( - {2, 3, 4}, // input sizes - 0.03, // scale - -2, // zero_point - -128, // quant_min - 127, // quant_max - at::kDouble, // input dtype - at::kChar); // output dtype -} diff --git a/backends/vulkan/test/op_tests/targets.bzl b/backends/vulkan/test/op_tests/targets.bzl index b9386f92772..dae2eddf8b2 100644 --- a/backends/vulkan/test/op_tests/targets.bzl +++ b/backends/vulkan/test/op_tests/targets.bzl @@ -177,33 +177,6 @@ def define_common_targets(is_fbcode = False): "//executorch/extension/tensor:tensor", ] ) - define_test_targets( - "quantize_test", - extra_deps = [ - ":test_utils", - "//executorch/kernels/quantized/cpu:op_quantize", - "//executorch/extension/tensor:tensor", - "//executorch/extension/aten_util:aten_bridge", - ] - ) - define_test_targets( - "dequantize_test", - extra_deps = [ - ":test_utils", - "//executorch/kernels/quantized/cpu:op_dequantize", - "//executorch/extension/tensor:tensor", - "//executorch/extension/aten_util:aten_bridge", - ] - ) - define_test_targets( - "choose_qparams_test", - extra_deps = [ - ":test_utils", - "//executorch/kernels/quantized/cpu:op_choose_qparams", - "//executorch/extension/tensor:tensor", - "//executorch/extension/aten_util:aten_bridge", - ] - ) define_test_targets( 
"quantized_linear_test", extra_deps = [ diff --git a/backends/vulkan/test/test_vulkan_delegate.py b/backends/vulkan/test/test_vulkan_delegate.py index f38c510a8b1..03a3263c293 100644 --- a/backends/vulkan/test/test_vulkan_delegate.py +++ b/backends/vulkan/test/test_vulkan_delegate.py @@ -1957,102 +1957,6 @@ def forward(self, x): sample_inputs, ) - def test_vulkan_backend_full_quantization_workflow(self): - class FullQuantizationWorkflowModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - # Step 1: Choose quantization parameters per tensor - scale, zero_point = ( - torch.ops.quantized_decomposed.choose_qparams.tensor( - x, - quant_min=-2147483648, # int32 min - quant_max=2147483647, # int32 max - eps=1e-5, - dtype=torch.int32, - ) - ) - - # Step 2: Quantize using the calculated parameters - quantized = torch.ops.quantized_decomposed.quantize_per_tensor.tensor( - x, - scale, - zero_point, - quant_min=-2147483648, # int32 min - quant_max=2147483647, # int32 max - dtype=torch.int32, - ) - - # Step 3: Dequantize back to float - dequantized = ( - torch.ops.quantized_decomposed.dequantize_per_tensor.tensor( - quantized, - scale, - zero_point, - quant_min=-2147483648, # int32 min - quant_max=2147483647, # int32 max - dtype=torch.int32, - ) - ) - - return dequantized - - full_workflow_module = FullQuantizationWorkflowModule() - sample_inputs = (torch.rand(size=(2, 3, 4), dtype=torch.float32),) - - # Use higher tolerance since quantization introduces some error - self.lower_module_and_test_output( - full_workflow_module, sample_inputs, atol=5e-3, rtol=5e-3 - ) - - def test_vulkan_backend_full_per_token_quantization_workflow(self): - class FullPerTokenQuantizationWorkflowModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - # Step 1: Choose quantization parameters per token - scale, zero_point = ( - torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default( - x, - dtype=torch.int32, - ) - ) - - # Step 2: Quantize using the calculated parameters per token - quantized = torch.ops.quantized_decomposed.quantize_per_token.default( - x, - scale, - zero_point, - quant_min=-2147483648, # int32 min - quant_max=2147483647, # int32 max - dtype=torch.int32, - ) - - # Step 3: Dequantize back to float per token - dequantized = ( - torch.ops.quantized_decomposed.dequantize_per_token.default( - quantized, - scale, - zero_point, - quant_min=-2147483648, # int32 min - quant_max=2147483647, # int32 max - dtype=torch.int32, - output_dtype=torch.float32, - ) - ) - - return dequantized - - full_per_token_workflow_module = FullPerTokenQuantizationWorkflowModule() - sample_inputs = (torch.rand(size=(6, 4), dtype=torch.float32),) - - # Use higher tolerance since quantization introduces some error - self.lower_module_and_test_output( - full_per_token_workflow_module, sample_inputs, atol=5e-3, rtol=5e-3 - ) - def test_vulkan_backend_different_required_reprs(self): class ComplexModule(torch.nn.Module): """