diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index 178cc9ea08b..33ed3150535 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -245,9 +245,9 @@ def register_ephemeral_op(features: OpFeatures): @update_features( [ - exir_ops.edge.quantized_decomposed.quantize_per_channel.default, exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor, + exir_ops.edge.quantized_decomposed.quantize_per_channel.default, exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor, exir_ops.edge.quantized_decomposed.dequantize_per_channel.default, @@ -276,14 +276,32 @@ def register_quantization_op(features: OpFeatures): [ exir_ops.edge.torchao.quantize_affine.default, exir_ops.edge.torchao.dequantize_affine.default, + ] +) +def register_affine_quantization_op(features: OpFeatures): + features.texture_impl = TextureImplFeatures( + uses_axis_map=False, + valid_packed_dims={PackedDim.WIDTH}, + ) + features.buffer_impl = True + features.resize_fn = True + features.optimal_storage = VkStorageType.TEXTURE_3D + features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED + features.handles_own_prepacking = True + + return features + + +@update_features( + [ exir_ops.edge.torchao.choose_qparams_affine.default, ] ) -def register_torchao_quantization_op(features: OpFeatures): - # TorchAO quantization operators - default to per-tensor behavior - # Same features as standard quantization ops +def register_choose_qparams_affine_op(features: OpFeatures): + # Currently only created a rudimentary buffer implementation for choose_qparams_affine + # since the reduction logic for blocks in texture3d is not trivial to implement in vulkan. features.texture_impl = TextureImplFeatures( - uses_axis_map=True, + uses_axis_map=False, valid_packed_dims={ PackedDim.WIDTH, }, @@ -292,37 +310,6 @@ def register_torchao_quantization_op(features: OpFeatures): features.resize_fn = True features.optimal_storage = VkStorageType.BUFFER - def check_torchao_quantization_node(node: torch.fx.Node) -> bool: - # Only per-tensor quantization is supported by the Vulkan backend. 
- if len(node.args) < 2: - return False - - block_size = node.args[1] - - if not isinstance(block_size, (list, tuple)): - return False - - input_arg = node.args[0] - if not isinstance(input_arg, torch.fx.Node): - return False - - input_tensor = input_arg.meta.get("val", None) - if not isinstance(input_tensor, FakeTensor): - return False - - input_shape = list(input_tensor.shape) - - if len(block_size) != len(input_shape): - return False - - # Check if block_size matches input_shape exactly (per-tensor quantization) - for i in range(len(block_size)): - if block_size[i] != input_shape[i]: - return False - - return True - - features.check_node_fn = check_torchao_quantization_node return features diff --git a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams.glslh b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams.glslh index d6d27d2e3a3..cfe5baa9c1d 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams.glslh @@ -9,59 +9,67 @@ #ifndef CHOOSE_QPARAMS_GLSLH #define CHOOSE_QPARAMS_GLSLH -// Calculate scale and zero point from min and max values -void calculate_scale_and_zero_point( - float min_val, - float max_val, - int qmin, - int qmax, - float eps_threshold, - out float scale_val, - out int zero_point_val) { - // ensure we have zero included in our range - min_val = min(min_val, 0.0); - max_val = max(max_val, 0.0); +// mapping_type : 0 = ASYM, 1 = SYM, 2 = SYM_NO_CLIP +void calc_scale_zp( + float lo, float hi, + int qmin, int qmax, + int mapping_type, + float eps, + out float scale, out int zp) { + // Handle case where lo and hi are +/-INF (no valid values found) + if (isinf(lo) || isinf(hi)) { + lo = 0.0; + hi = 0.0; + } - scale_val = (max_val - min_val) / float(qmax - qmin); + float minv = min(lo, 0.0); + float maxv = max(hi, 0.0); - // Handle zero or very small scale - if (scale_val == 0.0 || isinf(1.0 / scale_val)) { - scale_val = 0.1; - } + if (mapping_type == 0) { // asymmetric + scale = (maxv - minv) / float(qmax - qmin); + + // Handle zero or very small scale + if (scale == 0.0 || isinf(1.0/scale)) { + scale = eps; + } - // Cut off small scale using the provided eps threshold - if (scale_val < eps_threshold) { - float org_scale = scale_val; - scale_val = eps_threshold; + if (scale < eps) { + float org_scale = scale; + scale = eps; - // Adjust min and max based on new scale - if (min_val == 0.0) { - max_val = eps_threshold * float(qmax - qmin); - } else if (max_val == 0.0) { - min_val = -eps_threshold * float(qmax - qmin); - } else { - float amplifier = eps_threshold / org_scale; - min_val *= amplifier; - max_val *= amplifier; + // Adjust min and max based on new scale to maintain proper quantization range + if (minv == 0.0) { + maxv = eps * float(qmax - qmin); + } else if (maxv == 0.0) { + minv = -eps * float(qmax - qmin); + } else { + float amplifier = eps / org_scale; + minv *= amplifier; + maxv *= amplifier; + } + } + + // Calculate zero_point (matching reference implementation) + float initial_zero_point = float(qmin) - round(minv / scale); + zp = int(clamp(initial_zero_point, float(qmin), float(qmax))); + } else { // symmetric -- centred + float scale_sym; + if (mapping_type == 1) { // SYM + float M = max(abs(minv), abs(maxv)); + scale_sym = M / (float(qmax - qmin) * 0.5); + } else { // SYM_NO_CLIP + float smin = abs(minv) / max(abs(float(qmin)), 1.0); // Avoid division by zero + float smax = maxv / max(float(qmax), 1.0); // Avoid division by zero + scale_sym = max(smin, smax); } - } - // 
Calculate zero point - float zero_point_from_min = float(qmin) - min_val / scale_val; - float zero_point_from_max = float(qmax) - max_val / scale_val; - float zero_point_from_min_error = abs(float(qmin)) - abs(min_val / scale_val); - float zero_point_from_max_error = abs(float(qmax)) - abs(max_val / scale_val); - float initial_zero_point = zero_point_from_min_error < zero_point_from_max_error - ? zero_point_from_min - : zero_point_from_max; + // Handle zero or very small scale + if (scale_sym == 0.0 || isinf(1.0/scale_sym)) { + scale_sym = eps; + } - // Nudge zero point to integer - if (initial_zero_point < float(qmin)) { - zero_point_val = qmin; - } else if (initial_zero_point > float(qmax)) { - zero_point_val = qmax; - } else { - zero_point_val = int(round(initial_zero_point)); + scale = max(scale_sym, eps); + zp = int((qmax + qmin + 1) >> 1); // mid-point – always fits } } diff --git a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.glsl index 48681a46c30..99a64c3589e 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.glsl @@ -31,12 +31,22 @@ $if MODE == "per_tensor": int quant_max; float eps; }; -$else: +$if MODE == "per_token": layout(push_constant) uniform restrict Block { int num_tokens; int quant_min; int quant_max; }; +$if MODE == "block_wise": + layout(push_constant) uniform BlockPC { + ivec4 blockSize; // WHCN (>=1) + ivec4 numBlocks; // #blocks along W,H,C,N + ivec4 blockStride; // {1, #W, #W * #H, #W * #H * #C} + int mapping_type; // 0=ASYM, 1=SYM, 2=SYM_NO_CLIP + int quant_min; + int quant_max; + float eps; + }; ${layout_declare_ubo(B, "ivec4", "t_in_sizes")} ${layout_declare_ubo(B, "ivec4", "t_in_strides")} @@ -57,68 +67,133 @@ shared float shared_min[NWORKERS]; shared float shared_max[NWORKERS]; /* - * QUANTIZATION PARAMETER COMPUTATION SHADER (BUFFER STORAGE) - * - * This shader computes quantization parameters (scale and zero_point) for converting - * floating-point tensors to n-bit integer representations while preserving the - * original data range as much as possible. - * - * ALGORITHM: - * 1. Find global min/max values across tensor elements using parallel reduction - * 2. Use tree reduction with shared memory for efficient min/max computation - * 3. Calculate scale = (max - min) / (quant_max - quant_min) - * 4. 
Calculate zero_point to map floating-point zero to integer value - * - * WORKGROUP CONFIGURATION: - * - Per-Tensor Mode: - * - Global WG Size: {1, 1, 1} (single workgroup processes entire tensor) - * - Local WG Size: {64, 1, 1} (matches NWORKERS for shared memory) - * - Per-Token Mode: - * - Global WG Size: {num_tokens, 1, 1} (one workgroup per token) - * - Local WG Size: {64, 1, 1} (matches NWORKERS for shared memory) - * - * SUPPORTED CONFIGURATIONS: - * - Buffer Storage: Uses simple linear indexing through buffer elements - * - No axis mapping or packing considerations - processes elements sequentially - * - Works with any tensor layout since it accesses buffer data linearly - * - * TREE REDUCTION VISUALIZATION FOR MIN/MAX FINDING: - * For 8 threads processing elements [10, 1, 8, 1, 0, 2, 3, 5]: - * - * Initial shared_min/shared_max arrays populated by each thread: - * shared_min: | 10 | 1 | 8 | 1 | 0 | 2 | 3 | 5 | - * shared_max: | 10 | 1 | 8 | 1 | 0 | 2 | 3 | 5 | - * Thread: | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | - * - * Stride 1 (compare pairs, keep min/max): - * shared_min: | 1 | | 1 | | 0 | | 3 | | (min(10,1), min(8,1), min(0,2), min(3,5)) - * shared_max: | 10 | | 8 | | 2 | | 5 | | (max(10,1), max(8,1), max(0,2), max(3,5)) - * Active: | 0 | | 2 | | 4 | | 6 | | - * - * Stride 2 (compare pairs, keep min/max): - * shared_min: | 0 | | | | 0 | | | | (min(1,1), min(0,3)) - * shared_max: | 10 | | | | 5 | | | | (max(10,8), max(2,5)) - * Active: | 0 | | | | 4 | | | | - * - * Stride 4 (final comparison): - * shared_min: | 0 | | | | | | | | (min(0,0) = 0) - * shared_max: | 10 | | | | | | | | (max(10,5) = 10) - * Active: | 0 | | | | | | | | - * - * Final result: global_min = 0, global_max = 10 (stored in shared_min[0], shared_max[0]) - * - * PER-TENSOR QUANTIZATION: - * - Single workgroup processes entire tensor with strided access - * - Each thread processes elements [thread_id, thread_id + 64, thread_id + 128, ...] - * - Tree reduction combines all thread results into global min/max - * - Output: Single scale and zero_point values - * - * PER-TOKEN QUANTIZATION: - * - Multiple workgroups, each processing one token - * - Token = all elements except last dimension (e.g., for [B,S,H]: B*S tokens of H elements) - * - Each workgroup finds min/max within its assigned token - * - Output: Array of scale and zero_point values (one per token) - */ + Quantization Parameter Computation Shader (Buffer Storage) + This shader computes quantization parameters (scale and zero_point) for converting + floating-point tensors to n-bit integer representations while preserving the + original data range as much as possible. The computed parameters enable efficient + quantization by mapping the continuous floating-point range to discrete integer values. + + Important Considerations: + (+) The input tensor is assumed to be WIDTH_PACKED (i.e., contiguous in the last dimension) + + Workgroup Configuration: + - choose_qparams_per_tensor + This mode computes a single set of quantization parameters for the entire tensor. + Uses parallel reduction across all threads to find global min/max values. + + (*) global_wg_size: {1, 1, 1} (single workgroup processes entire tensor) + (*) local_wg_size: {64, 1, 1} (matches NWORKERS for shared memory) + + - choose_qparams_per_token + This mode computes separate quantization parameters for each token in the tensor. + Each workgroup processes one token independently to find token-specific min/max. 
+ + (*) global_wg_size: {num_tokens, 1, 1} (one workgroup per token) + (*) local_wg_size: {1, 1, 1} (single thread per token) + + - choose_qparams_block_wise + This mode computes quantization parameters for each block of elements, allowing + fine-grained control over quantization granularity within the tensor. Each block + is processed independently to find its own min/max values and compute corresponding + scale and zero_point parameters. + + (*) global_wg_size: {nBlocks, 1u, 1u} (one workgroup per block) + (*) local_wg_size: {1, 1, 1} (single thread per block) + + Block-wise quantization supports multiple mapping types for scale/zero_point calculation: + + - mapping_type = 0 (ASYMMETRIC): + Uses asymmetric quantization where the full floating-point range [min, max] is + mapped to the quantized range [quant_min, quant_max]. This preserves the original + data distribution but may not center zero optimally. + + Calculation: + scale = (max - min) / (quant_max - quant_min) + zero_point = quant_min - round(min / scale) + + Example: For range [-3.5, 10.2] mapping to int4 [-8, 7]: + scale = (10.2 - (-3.5)) / (7 - (-8)) = 13.7 / 15 = 0.913 + zero_point = -8 - round(-3.5 / 0.913) = -8 - (-4) = -4 + + - mapping_type = 1 (SYMMETRIC): + Uses symmetric quantization where the range is centered around zero. The scale + is computed based on the maximum absolute value, ensuring zero is exactly + representable in the quantized domain. + + Calculation: + max_abs = max(abs(min), abs(max)) + scale = max_abs / ((quant_max - quant_min) / 2) + zero_point = (quant_max + quant_min + 1) / 2 // midpoint + + Example: For range [-3.5, 10.2] mapping to int4 [-8, 7]: + max_abs = max(3.5, 10.2) = 10.2 + scale = 10.2 / ((7 - (-8)) / 2) = 10.2 / 7.5 = 1.36 + zero_point = (-8 + 7 + 1) / 2 = 0 + + - mapping_type = 2 (SYMMETRIC_NO_CLIPPING_ERR): + A variant of symmetric quantization that minimizes clipping errors by computing + separate scales for positive and negative ranges, then using the maximum. This + reduces quantization error on the dominant range while ensuring no values are + clipped. + + Calculation: + smin = abs(min) / abs(quant_min) // scale for negative range + smax = max / quant_max // scale for positive range + scale = max(smin, smax) // use larger scale to avoid clipping + zero_point = (quant_max + quant_min + 1) / 2 // midpoint + + Example: For range [-3.5, 10.2] mapping to int4 [-8, 7]: + smin = 3.5 / 8 = 0.4375 + smax = 10.2 / 7 = 1.457 + scale = max(0.4375, 1.457) = 1.457 // use smax to avoid clipping positives + zero_point = (-8 + 7 + 1) / 2 = 0 + + Tree Reduction Algorithm for Min/Max Finding: + The shader uses a parallel tree reduction algorithm to efficiently find minimum and + maximum values across multiple threads. This approach reduces the number of memory + accesses and synchronization points compared to sequential scanning. + + Example with 8 threads processing values [10, 1, 8, 1, 0, 2, 3, 5]: + + Step 1 - Initial Population: + Each thread loads its assigned value into shared memory arrays. + shared_min: | 10 | 1 | 8 | 1 | 0 | 2 | 3 | 5 | + shared_max: | 10 | 1 | 8 | 1 | 0 | 2 | 3 | 5 | + Thread ID: | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | + + Step 2 - Stride 1 (Compare Adjacent Pairs): + Threads 0,2,4,6 compare with threads 1,3,5,7 respectively. 
+ shared_min: | 1 | | 1 | | 0 | | 3 | | (min(10,1), min(8,1), min(0,2), min(3,5)) + shared_max: | 10 | | 8 | | 2 | | 5 | | (max(10,1), max(8,1), max(0,2), max(3,5)) + Active: | 0 | | 2 | | 4 | | 6 | | + + Step 3 - Stride 2 (Compare Pairs of Pairs): + Threads 0,4 compare with threads 2,6 respectively. + shared_min: | 1 | | | | 0 | | | | (min(1,1), min(0,3)) + shared_max: | 10 | | | | 5 | | | | (max(10,8), max(2,5)) + Active: | 0 | | | | 4 | | | | + + Step 4 - Stride 4 (Final Comparison): + Thread 0 compares with thread 4 to get final result. + shared_min: | 0 | | | | | | | | (min(1,0) = 0) + shared_max: | 10 | | | | | | | | (max(10,5) = 10) + Active: | 0 | | | | | | | | + + Final Result: global_min = 0, global_max = 10 (stored in shared_min[0], shared_max[0]) + + The tree reduction completes in log_2(N) steps where N is the number of threads, + providing O(log N) time complexity instead of O(N) for sequential reduction. + + Quantization Parameter Calculation: + Once min/max values are determined, the shader computes: + - scale = (max - min) / (quant_max - quant_min) + - zero_point = quantization offset to map floating-point zero to integer range + + Mode-Specific Behavior: + - Per-Tensor: Single workgroup with strided access across entire tensor + - Per-Token: Multiple workgroups, each processing one token independently + - Block-Wise: Each thread processes assigned blocks using nested loops over block dimensions +*/ #ifdef per_tensor @@ -176,99 +251,141 @@ void choose_qparams_per_tensor() { float scale_val; int zero_point_val; - calculate_scale_and_zero_point(global_min, global_max, quant_min, quant_max, eps, scale_val, zero_point_val); + // Use default values: mapping_type=0 (ASYMMETRIC), eps from push constant + calc_scale_zp(global_min, global_max, quant_min, quant_max, 0, eps, scale_val, zero_point_val); t_scale[0] = scale_val; t_zero_point[0] = zero_point_val; } } -#else +#elif defined(per_token) void choose_qparams_per_token() { - uint global_id = gl_GlobalInvocationID.x; - uint local_id = gl_LocalInvocationID.x; - uint group_id = gl_WorkGroupID.x; - uint total_workgroups = gl_NumWorkGroups.x; - uint total_elements = uint(t_in_sizes.x * t_in_sizes.y * t_in_sizes.z * t_in_sizes.w); uint token_size = total_elements / uint(num_tokens); - // Calculate how many tokens each workgroup should process - // This handles the case where we have more tokens than workgroups - uint tokens_per_workgroup = (uint(num_tokens) + total_workgroups - 1) / total_workgroups; - - // Calculate which tokens this workgroup is responsible for - uint start_token = group_id * tokens_per_workgroup; - uint end_token = min(start_token + tokens_per_workgroup, uint(num_tokens)); + const uint TOTAL_TOKENS = uint(num_tokens); - // Early exit if this workgroup has no tokens to process - if (start_token >= uint(num_tokens)) { - return; - } - - // Process each token assigned to this workgroup - for (uint token_id = start_token; token_id < end_token; token_id++) { + /* each invocation handles token-ids: id, id+STRIDE, id+2·STRIDE … */ + const uint STRIDE = gl_WorkGroupSize.x * gl_NumWorkGroups.x; + for (uint token_id = gl_GlobalInvocationID.x; token_id < TOTAL_TOKENS; token_id += STRIDE) { // Calculate the start and end indices for this token uint token_start = token_id * token_size; uint token_end = token_start + token_size; - // Each thread processes multiple elements within the token with stride - float thread_min = 1.0/0.0; // +infinity - float thread_max = -1.0/0.0; // -infinity + // Each thread processes the entire 
token + float lo = 1.0/0.0; // +INF + float hi = -1.0/0.0; // -INF bool found_valid = false; - // Process elements within this token only - for (uint i = token_start + local_id; i < token_end; i += gl_WorkGroupSize.x) { + // Process all elements in this token + for (uint i = token_start; i < token_end; i++) { float val = t_in[i]; if (!isnan(val) && !isinf(val)) { if (!found_valid) { - thread_min = val; - thread_max = val; + lo = hi = val; found_valid = true; } else { - thread_min = min(thread_min, val); - thread_max = max(thread_max, val); + lo = min(lo, val); + hi = max(hi, val); } } } - // Intra-group reduction using shared memory - shared_min[local_id] = thread_min; - shared_max[local_id] = thread_max; - barrier(); + if (!found_valid) { + // If no valid values were found, use default values + lo = 0.0; + hi = 0.0; + } - // Tree reduction within work group - for (uint stride = gl_WorkGroupSize.x / 2; stride > 0; stride >>= 1) { - if (local_id < stride) { - float other_min = shared_min[local_id + stride]; - float other_max = shared_max[local_id + stride]; + // Calculate scale and zero point directly + float scale_val; + int zero_point_val; + // Use default values: mapping_type=0 (ASYMMETRIC), eps=1e-5 + calc_scale_zp(lo, hi, quant_min, quant_max, 0, 1e-5, scale_val, zero_point_val); - if (!isinf(other_min) && (isinf(shared_min[local_id]) || other_min < shared_min[local_id])) { - shared_min[local_id] = other_min; - } - if (!isinf(other_max) && (isinf(shared_max[local_id]) || other_max > shared_max[local_id])) { - shared_max[local_id] = other_max; + // Write results + t_scale[token_id] = scale_val; + t_zero_point[token_id] = zero_point_val; + } +} + +#elif defined(block_wise) + +ivec4 block_id_to_coord(uint bid) { + ivec4 bc; + bc.w = int(bid) / blockStride.w; + + int r = int(bid) - bc.w * blockStride.w; + bc.z = r / blockStride.z; + + r -= bc.z * blockStride.z; + bc.y = r / blockStride.y; + + r -= bc.y * blockStride.y; + bc.x = r; + return bc; +} + +void choose_qparams_block_wise() { + const uint TOTAL_BLOCKS = uint(numBlocks.x * numBlocks.y * numBlocks.z * numBlocks.w); + + // each invocation handles block-ids: id, id+STRIDE, id+2·STRIDE + const uint STRIDE = gl_WorkGroupSize.x * gl_NumWorkGroups.x; + for (uint block_id = gl_GlobalInvocationID.x; block_id < TOTAL_BLOCKS; block_id += STRIDE) { + // block -> WHCN coordinate + ivec4 bc = block_id_to_coord(block_id); + ivec4 blockStart = bc * blockSize; // first element (inclusive) + ivec4 blockEnd = blockStart + blockSize; // last element (exclusive) + + // min / max scan over the block + float lo = 1.0/0.0; // +INF + float hi = -1.0/0.0; // -INF + bool found_valid = false; + + // Calculate actual block dimensions + ivec4 actualBlockSize = blockEnd - blockStart; + int blockElements = actualBlockSize.x * actualBlockSize.y * actualBlockSize.z * actualBlockSize.w; + + // Linear iteration over block elements + for (int elemIdx = 0; elemIdx < blockElements; ++elemIdx) { + // Convert linear index to 4D coordinates within block + int remaining = elemIdx; + int dn = remaining / (actualBlockSize.x * actualBlockSize.y * actualBlockSize.z); + remaining -= dn * (actualBlockSize.x * actualBlockSize.y * actualBlockSize.z); + int dc = remaining / (actualBlockSize.x * actualBlockSize.y); + remaining -= dc * (actualBlockSize.x * actualBlockSize.y); + int dh = remaining / actualBlockSize.x; + int dw = remaining - dh * actualBlockSize.x; + + ivec4 tidx = blockStart + ivec4(dw, dh, dc, dn); + uint idx = tidx_to_bufi(tidx, t_in_strides); + float v = t_in[idx]; + 
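+      // NaN/Inf elements are skipped so they cannot corrupt the block's min/max;
+      // if a block contains no finite values, lo/hi fall back to 0.0 after the loop.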
+ if (!isnan(v) && !isinf(v)) { + if (!found_valid) { + lo = hi = v; + found_valid = true; + } else { + lo = min(lo, v); + hi = max(hi, v); } } - barrier(); } - // Final calculation for this token - if (local_id == 0) { - float token_min = shared_min[0]; - float token_max = shared_max[0]; - - float scale_val; - int zero_point_val; - calculate_scale_and_zero_point(token_min, token_max, quant_min, quant_max, 1e-5, scale_val, zero_point_val); - - t_scale[token_id] = scale_val; - t_zero_point[token_id] = zero_point_val; + // Handle the case where no valid values were found in the block + if (!found_valid) { + lo = 0.0; + hi = 0.0; } - // Synchronize before processing next token - barrier(); + float scale; + int zp; + calc_scale_zp(lo, hi, quant_min, quant_max, mapping_type, eps, scale, zp); + + t_zero_point[block_id] = zp; + t_scale[block_id] = scale; } } diff --git a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.yaml index c37039f68e9..ee900750e16 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.yaml @@ -10,3 +10,5 @@ choose_qparams_buffer: MODE: per_tensor - NAME: choose_qparams_per_token_asymmetric_buffer MODE: per_token + - NAME: choose_qparams_block_wise_buffer + MODE: block_wise diff --git a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.glsl index 5076b2d68e9..62ea7099f8c 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.glsl @@ -22,8 +22,13 @@ ${define_required_extensions(IN_DTYPE)} layout(std430) buffer; -${layout_declare_tensor(B, "w", "t_scale", "float", "texture3d")} -${layout_declare_tensor(B, "w", "t_zero_point", "int", "texture3d")} +$if MODE != "block_wise": + ${layout_declare_tensor(B, "w", "t_scale", "float", "texture3d")} + ${layout_declare_tensor(B, "w", "t_zero_point", "int", "texture3d")} +$else: + ${layout_declare_tensor(B, "w", "t_scale", "float", "buffer")} + ${layout_declare_tensor(B, "w", "t_zero_point", "int", "buffer")} + ${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "texture3d")} $if MODE == "per_tensor": @@ -32,16 +37,33 @@ $if MODE == "per_tensor": int quant_max; float eps; }; -$else: +$if MODE == "per_token": layout(push_constant) uniform restrict Block { int num_tokens; int quant_min; int quant_max; }; +$if MODE == "block_wise": + layout(push_constant) uniform BlockPC { + ivec4 blockSize; // WHCN (>=1) + ivec4 numBlocks; // #blocks along W,H,C,N + ivec4 blockStride; // {1, #W, #W * #H, #W * #H * #C} + int mapping_type; // 0=ASYM, 1=SYM, 2=SYM_NO_CLIP + int quant_min; + int quant_max; + float eps; + }; ${layout_declare_ubo(B, "ivec3", "t_in_limits")} -${layout_declare_ubo(B, "ivec3", "t_scale_limits")} -${layout_declare_ubo(B, "ivec3", "t_zero_point_limits")} +$if MODE != "block_wise": + ${layout_declare_ubo(B, "ivec3", "t_scale_limits")} + ${layout_declare_ubo(B, "ivec3", "t_zero_point_limits")} +$else: + ${layout_declare_ubo(B, "ivec4", "t_scale_sizes")} + ${layout_declare_ubo(B, "ivec4", "t_scale_strides")} + ${layout_declare_ubo(B, "ivec4", "t_zero_point_sizes")} + ${layout_declare_ubo(B, "ivec4", "t_zero_point_strides")} + #include "indexing_utils.h" #include "choose_qparams.glslh" @@ -54,73 +76,87 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; shared float 
shared_min[NWORKERS]; shared float shared_max[NWORKERS]; -/* - * QUANTIZATION PARAMETER COMPUTATION SHADER (TEXTURE STORAGE) - * - * This shader computes quantization parameters (scale and zero_point) for converting - * floating-point tensors to n-bit integer representations while preserving the - * original data range as much as possible. - * - * ALGORITHM: - * 1. Find global min/max values across tensor elements using parallel reduction - * 2. Use tree reduction with shared memory for efficient min/max computation - * 3. Calculate scale = (max - min) / (quant_max - quant_min) - * 4. Calculate zero_point to map floating-point zero to integer value - * - * WORKGROUP CONFIGURATION: - * - Per-Tensor Mode: - * - Global WG Size: Default (typically {num_elements, 1, 1}) - * - Local WG Size: Default (typically {64, 1, 1}) - * - Per-Token Mode: - * - Global WG Size: Default (typically based on tensor dimensions) - * - Local WG Size: Default (typically {64, 1, 1}, or based on global WG size) - * - * SUPPORTED CONFIGURATIONS: - * - Texture Storage: Uses 3D texture indexing with linear texel iteration - * - Assumes width-packed layout (packed_dim = 0) in current implementation - * - Handles texel padding for non-multiple-of-4 tensor dimensions - * - Note: Axis mapping support depends on indexing utilities - * - * TREE REDUCTION VISUALIZATION FOR MIN/MAX FINDING: - * For 8 threads processing elements [10, 1, 8, 1, 0, 2, 3, 5]: - * - * Initial shared_min/shared_max arrays populated by each thread: - * shared_min: | 10 | 1 | 8 | 1 | 0 | 2 | 3 | 5 | - * shared_max: | 10 | 1 | 8 | 1 | 0 | 2 | 3 | 5 | - * Thread: | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | - * - * Stride 1 (compare pairs, keep min/max): - * shared_min: | 1 | | 1 | | 0 | | 3 | | (min(10,1), min(8,1), min(0,2), min(3,5)) - * shared_max: | 10 | | 8 | | 2 | | 5 | | (max(10,1), max(8,1), max(0,2), max(3,5)) - * Active: | 0 | | 2 | | 4 | | 6 | | - * - * Stride 2 (compare pairs, keep min/max): - * shared_min: | 0 | | | | 0 | | | | (min(1,1), min(0,3)) - * shared_max: | 10 | | | | 5 | | | | (max(10,8), max(2,5)) - * Active: | 0 | | | | 4 | | | | - * - * Stride 4 (final comparison): - * shared_min: | 0 | | | | | | | | (min(0,0) = 0) - * shared_max: | 10 | | | | | | | | (max(10,5) = 10) - * Active: | 0 | | | | | | | | - * - * Final result: global_min = 0, global_max = 10 (stored in shared_min[0], shared_max[0]) - * - * PER-TENSOR QUANTIZATION: - * - Single workgroup processes entire tensor - * - Each thread processes multiple texels with stride - * - Thread 0: texels [0, 64, 128, ...] -> elements [0-3, 256-259, 512-515, ...] - * - Thread 1: texels [1, 65, 129, ...] -> elements [4-7, 260-263, 516-519, ...] - * - Tree reduction combines all thread results into global min/max - * - Output: Single scale and zero_point values - * - * PER-TOKEN QUANTIZATION: - * - Multiple workgroups, each processing subset of tokens - * - Token = all elements except last dimension (e.g., for [B,S,H]: B*S tokens of H elements) - * - Each workgroup processes multiple tokens if num_tokens > num_workgroups - * - Within each token, threads process texels containing token elements - * - Output: Array of scale and zero_point values (one per token) - */ +/*/* + Quantization Parameter Computation Shader (Buffer Storage) + This shader computes quantization parameters (scale and zero_point) for converting + floating-point tensors to n-bit integer representations while preserving the + original data range as much as possible. 
The computed parameters enable efficient + quantization by mapping the continuous floating-point range to discrete integer values. + + Important Considerations: + (+) The input tensor is assumed to be WIDTH_PACKED (i.e., contiguous in the last dimension) + + Workgroup Configuration: + - choose_qparams_per_tensor + This mode computes a single set of quantization parameters for the entire tensor. + Uses parallel reduction across all threads to find global min/max values. + + (*) global_wg_size: default + (*) local_wg_size: default + + - choose_qparams_per_token + This mode computes separate quantization parameters for each token in the tensor. + Each workgroup processes one token independently to find token-specific min/max. + + (*) global_wg_size: default + (*) local_wg_size: {1, 1, 1} + + - choose_qparams_block_wise + This mode computes quantization parameters for each block of elements, allowing + fine-grained control over quantization granularity within the tensor. Each block + is processed independently to find its own min/max values and compute corresponding + scale and zero_point parameters. + + NOTE: This mode currently only supports buffer storage for the output. + + (*) global_wg_size: {nBlocks, 1u, 1u} (one workgroup per block) + (*) local_wg_size: {1, 1, 1} (single thread per block) + + Tree Reduction Algorithm for Min/Max Finding: + The shader uses a parallel tree reduction algorithm to efficiently find minimum and + maximum values across multiple threads. This approach reduces the number of memory + accesses and synchronization points compared to sequential scanning. + + Example with 8 threads processing values [10, 1, 8, 1, 0, 2, 3, 5]: + + Step 1 - Initial Population: + Each thread loads its assigned value into shared memory arrays. + shared_min: | 10 | 1 | 8 | 1 | 0 | 2 | 3 | 5 | + shared_max: | 10 | 1 | 8 | 1 | 0 | 2 | 3 | 5 | + Thread ID: | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | + + Step 2 - Stride 1 (Compare Adjacent Pairs): + Threads 0,2,4,6 compare with threads 1,3,5,7 respectively. + shared_min: | 1 | | 1 | | 0 | | 3 | | (min(10,1), min(8,1), min(0,2), min(3,5)) + shared_max: | 10 | | 8 | | 2 | | 5 | | (max(10,1), max(8,1), max(0,2), max(3,5)) + Active: | 0 | | 2 | | 4 | | 6 | | + + Step 3 - Stride 2 (Compare Pairs of Pairs): + Threads 0,4 compare with threads 2,6 respectively. + shared_min: | 1 | | | | 0 | | | | (min(1,1), min(0,3)) + shared_max: | 10 | | | | 5 | | | | (max(10,8), max(2,5)) + Active: | 0 | | | | 4 | | | | + + Step 4 - Stride 4 (Final Comparison): + Thread 0 compares with thread 4 to get final result. + shared_min: | 0 | | | | | | | | (min(1,0) = 0) + shared_max: | 10 | | | | | | | | (max(10,5) = 10) + Active: | 0 | | | | | | | | + + Final Result: global_min = 0, global_max = 10 (stored in shared_min[0], shared_max[0]) + + The tree reduction completes in log_2(N) steps where N is the number of threads, + providing O(log N) time complexity instead of O(N) for sequential reduction. 
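+  For example, with a local workgroup size of 64 threads the reduction finishes in
+  log_2(64) = 6 steps.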
+ + Quantization Parameter Calculation: + Once min/max values are determined, the shader computes: + - scale = (max - min) / (quant_max - quant_min) + - zero_point = quantization offset to map floating-point zero to integer range + + Mode-Specific Behavior: + - Per-Tensor: Single workgroup with strided access across entire tensor + - Per-Token: Multiple workgroups, each processing one token independently +*/ #ifdef per_tensor @@ -235,14 +271,14 @@ void choose_qparams_per_tensor() { float scale_val; int zero_point_val; - calculate_scale_and_zero_point(global_min, global_max, quant_min, quant_max, eps, scale_val, zero_point_val); + calc_scale_zp(global_min, global_max, quant_min, quant_max, 0, eps, scale_val, zero_point_val); write_texel(t_scale, ivec3(0, 0, 0), vec4(scale_val, 0.0, 0.0, 0.0)); write_texel(t_zero_point, ivec3(0, 0, 0), ivec4(zero_point_val, 0, 0, 0)); } } -#else +#elif defined(per_token) void choose_qparams_per_token() { // Each token is processed by multiple workgroups for parallel reduction @@ -373,7 +409,7 @@ void choose_qparams_per_token() { float scale_val; int zero_point_val; - calculate_scale_and_zero_point(token_min, token_max, quant_min, quant_max, 1e-5, scale_val, zero_point_val); + calc_scale_zp(token_min, token_max, quant_min, quant_max, 0, 1e-5, scale_val, zero_point_val); // Convert token_id to 3D coordinates for output texture // Assuming output tensors have the same layout as input but with different dimensions @@ -392,6 +428,100 @@ void choose_qparams_per_token() { } } +#elif defined(block_wise) + +ivec4 block_id_to_coord(uint bid) { + ivec4 bc; + bc.w = int(bid) / blockStride.w; + + int r = int(bid) - bc.w * blockStride.w; + bc.z = r / blockStride.z; + + r -= bc.z * blockStride.z; + bc.y = r / blockStride.y; + + r -= bc.y * blockStride.y; + bc.x = r; + return bc; +} + +void choose_qparams_block_wise() { + const uint T = uint(numBlocks.x * numBlocks.y * numBlocks.z * numBlocks.w); + const uint STRIDE = gl_WorkGroupSize.x * gl_NumWorkGroups.x; + + // tensor full size in WHCN order + const ivec4 tensorSz = blockSize * numBlocks; + + // Process blocks with stride for better parallelization + for (uint blkIdx = gl_GlobalInvocationID.x; blkIdx < T; blkIdx += STRIDE) { + // block index in WHCN + const ivec4 b4d = block_id_to_coord(blkIdx); + const ivec4 blockStart = b4d * blockSize; + const ivec4 blockEnd = blockStart + blockSize; + + // scan all elements inside the block + float vmin = 3.402823e38; // +FLT_MAX + float vmax = -3.402823e38; // -FLT_MAX + bool found_valid = false; + + // Calculate total elements in block for linear iteration + const int blockElements = blockSize.x * blockSize.y * blockSize.z * blockSize.w; + + // Linear iteration over block elements (more cache-friendly) + for (int elemIdx = 0; elemIdx < blockElements; ++elemIdx) { + // Convert linear index to 4D coordinates within block + int remaining = elemIdx; + int dn = remaining / (blockSize.x * blockSize.y * blockSize.z); + remaining -= dn * (blockSize.x * blockSize.y * blockSize.z); + int dc = remaining / (blockSize.x * blockSize.y); + remaining -= dc * (blockSize.x * blockSize.y); + int dh = remaining / blockSize.x; + int dw = remaining - dh * blockSize.x; + + ivec4 tidx = blockStart + ivec4(dw, dh, dc, dn); + + // skip padding when tensor size is not an exact multiple of block + if (any(greaterThanEqual(tidx, tensorSz))) { continue; } + + // tensor index -> (x,y,z,component) inside input texture + ivec4 posi = to_texture_elem_pos(tidx, tensorSz, 0); // 0 = W_DIM (width packed) + + // fetch 
texel and pick the element inside it + FVEC4_T texl = load_texel(t_in, posi.xyz); + float v; + if (posi.w == 0) v = texl.x; + else if (posi.w == 1) v = texl.y; + else if (posi.w == 2) v = texl.z; + else v = texl.w; + + if (!isnan(v) && !isinf(v)) { + if (!found_valid) { + vmin = vmax = v; + found_valid = true; + } else { + vmin = min(vmin, v); + vmax = max(vmax, v); + } + } + } + + // Handle case where no valid values were found + if (!found_valid) { + vmin = 0.0; + vmax = 0.0; + } + + // compute scale / zero‑point (same maths as buffer kernel) + float scale; + int zp; + calc_scale_zp(vmin, vmax, quant_min, quant_max, mapping_type, eps, scale, zp); + + // Write the scalar values directly to buffer using linear index + t_scale[blkIdx] = scale; + t_zero_point[blkIdx] = zp; + } +} + #endif void main() { diff --git a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.yaml b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.yaml index f3961b87a0f..a097ce0da48 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.yaml @@ -10,3 +10,5 @@ choose_qparams_texture: MODE: per_tensor - NAME: choose_qparams_per_token_asymmetric_texture3d MODE: per_token + - NAME: choose_qparams_block_wise_texture3d + MODE: block_wise diff --git a/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp b/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp index de269920eea..76d352334e3 100644 --- a/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp @@ -14,45 +14,6 @@ namespace vkcompute { -namespace { - -void resize_choose_qparams_tensor_output( - ComputeGraph* graph, - const std::vector& args, - const std::vector& extra_args) { - (void)extra_args; - const ValueRef scale_out = args.at(0).refs.at(0); - const ValueRef zero_point_out = args.at(0).refs.at(1); - - // Both scale and zero_point are scalar tensors for per-tensor quantization - // Since we use single workgroup approach, no extra buffer space needed - graph->virtual_resize(scale_out, {}); - graph->virtual_resize(zero_point_out, {}); -} - -void resize_choose_qparams_per_token_output( - ComputeGraph* graph, - const std::vector& args, - const std::vector& extra_args) { - (void)extra_args; - const ValueRef scale_out = args.at(0).refs.at(0); - const ValueRef zero_point_out = args.at(0).refs.at(1); - const ValueRef input = args.at(1).refs.at(0); - - // Calculate output sizes for scale and zero_point tensors - const auto input_sizes = graph->sizes_of(input); - std::vector output_sizes; - output_sizes.reserve(input_sizes.size() - 1); - for (size_t i = 0; i < input_sizes.size() - 1; i++) { - output_sizes.push_back(input_sizes[i]); - } - output_sizes.push_back(1); - - graph->virtual_resize(scale_out, output_sizes); - graph->virtual_resize(zero_point_out, output_sizes); -} - -// Custom workgroup size pickers for ChooseQParams operations utils::uvec3 choose_qparams_pick_global_wg_size( ComputeGraph* graph, const vkapi::ShaderInfo& shader, @@ -135,15 +96,67 @@ utils::uvec3 choose_qparams_per_token_pick_local_wg_size( const ValueRef input = args.at(1).refs.at(0); if (graph->is_buffer_storage(input)) { - // For buffer storage, use 64 threads in X dimension to match NWORKERS - return {64u, 1u, 1u}; + return {1u, 1u, 1u}; } else { // For texture storage, use the default logic return graph->create_local_wg_size(global_workgroup_size); } } -} // namespace +utils::uvec3 
choose_qparams_block_wise_pick_global_wg_size( + ComputeGraph* g, + const vkapi::ShaderInfo&, + const std::vector& a, + const std::vector& r) { + const ValueRef input = a.at(2).refs.at(0); + const auto blkRef = r.at(0); + const auto inSz = g->sizes_of(input); + const auto blkList = g->get_int_list(blkRef); + + // Use same code as in add_choose_qparams_block_wise_node + utils::ivec4 block_size_vec = utils::make_whcn_ivec4(*blkList); + utils::ivec4 tensor_size_whcn = utils::make_whcn_ivec4(inSz); + + // Calculate numBlocks: ceil(tensorSize / blockSize) (both in WHCN order) + utils::ivec4 nBlk = { + (tensor_size_whcn[0] + block_size_vec[0] - 1) / block_size_vec[0], + (tensor_size_whcn[1] + block_size_vec[1] - 1) / block_size_vec[1], + (tensor_size_whcn[2] + block_size_vec[2] - 1) / block_size_vec[2], + (tensor_size_whcn[3] + block_size_vec[3] - 1) / block_size_vec[3]}; + + uint32_t nBlocks = nBlk[0] * nBlk[1] * nBlk[2] * nBlk[3]; + + // For texture storage, use more threads to better utilize GPU parallelism + // Each thread can process multiple blocks with stride + if (g->is_buffer_storage(input)) { + return {nBlocks, 1u, 1u}; + } else { + // For texture storage, use more workgroups to better utilize GPU + // Aim for ~64-256 threads per workgroup for good occupancy + uint32_t preferred_threads_per_wg = 64; + uint32_t num_workgroups = + (nBlocks + preferred_threads_per_wg - 1) / preferred_threads_per_wg; + num_workgroups = std::max(1u, std::min(num_workgroups, nBlocks)); + return {num_workgroups * preferred_threads_per_wg, 1u, 1u}; + } +} + +utils::uvec3 choose_qparams_block_wise_pick_local_wg_size( + ComputeGraph* g, + const vkapi::ShaderInfo&, + const utils::uvec3& global_wg_size, + const std::vector& a, + const std::vector&) { + const ValueRef input = a.at(2).refs.at(0); + + if (g->is_buffer_storage(input)) { + return {1u, 1u, 1u}; + } else { + // For texture storage, use 64 threads per workgroup for better occupancy + uint32_t local_size = std::min(64u, global_wg_size[0]); + return {local_size, 1u, 1u}; + } +} void add_choose_qparams_tensor_node( ComputeGraph& graph, @@ -162,6 +175,7 @@ void add_choose_qparams_tensor_node( float eps_val = static_cast(graph.get_double(eps)); vkapi::ParamsBindList param_ubos; + std::vector push_constants; if (graph.is_buffer_storage(input)) { param_ubos = { @@ -178,7 +192,6 @@ void add_choose_qparams_tensor_node( graph.logical_limits_ubo(zero_point_out)}; } - std::vector push_constants; push_constants = { PushConstantDataInfo(&quant_min_val, sizeof(int)), PushConstantDataInfo(&quant_max_val, sizeof(int)), @@ -203,7 +216,7 @@ void add_choose_qparams_tensor_node( // Resize Args {}, // Resizing Logic - resize_choose_qparams_tensor_output)); + nullptr)); } void add_choose_qparams_per_token_asymmetric_node( @@ -227,6 +240,7 @@ void add_choose_qparams_per_token_asymmetric_node( int quant_max_val = 127; // Fixed for asymmetric quantization vkapi::ParamsBindList param_ubos; + std::vector push_constants; if (graph.is_buffer_storage(input)) { param_ubos = { @@ -243,7 +257,6 @@ void add_choose_qparams_per_token_asymmetric_node( graph.logical_limits_ubo(zero_point_out)}; } - std::vector push_constants; push_constants = { PushConstantDataInfo(&num_tokens_val, sizeof(int)), PushConstantDataInfo(&quant_min_val, sizeof(int)), @@ -268,7 +281,100 @@ void add_choose_qparams_per_token_asymmetric_node( // Resize Args {}, // Resizing Logic - resize_choose_qparams_per_token_output)); + nullptr)); +} + +void add_choose_qparams_block_wise_node( + ComputeGraph& graph, + ValueRef 
input, + ValueRef block_size, + int mapping_type, // 0 / 1 / 2 + ValueRef quant_min, + ValueRef quant_max, + ValueRef eps, + ValueRef scale_out, + ValueRef zp_out) { + const auto input_sizes = graph.sizes_of(input); + const auto block_size_list = graph.get_int_list(block_size); + + // For shader compatibility, we still need to convert to WHCN order + // but the output shape calculation is now handled correctly in resize + // function + utils::ivec4 block_size_vec = utils::make_whcn_ivec4(*block_size_list); + utils::ivec4 tensor_size_whcn = utils::make_whcn_ivec4(input_sizes); + + // Calculate numBlocks: ceil(tensorSize / blockSize) (both in WHCN order) + utils::ivec4 num_blocks_vec = { + (tensor_size_whcn[0] + block_size_vec[0] - 1) / block_size_vec[0], + (tensor_size_whcn[1] + block_size_vec[1] - 1) / block_size_vec[1], + (tensor_size_whcn[2] + block_size_vec[2] - 1) / block_size_vec[2], + (tensor_size_whcn[3] + block_size_vec[3] - 1) / block_size_vec[3]}; + + // Calculate blockStride: pre-computed linear strides for the block grid + utils::ivec4 block_stride_vec = { + 1, + num_blocks_vec[0], + num_blocks_vec[0] * num_blocks_vec[1], + num_blocks_vec[0] * num_blocks_vec[1] * num_blocks_vec[2]}; + + int qmin = static_cast(graph.get_int(quant_min)); + int qmax = static_cast(graph.get_int(quant_max)); + float eps_val = static_cast(graph.get_double(eps)); + + // Create push constants vector + std::vector push_constants = { + PushConstantDataInfo(&block_size_vec, sizeof(block_size_vec)), + PushConstantDataInfo(&num_blocks_vec, sizeof(num_blocks_vec)), + PushConstantDataInfo(&block_stride_vec, sizeof(block_stride_vec)), + PushConstantDataInfo(&mapping_type, sizeof(int)), + PushConstantDataInfo(&qmin, sizeof(int)), + PushConstantDataInfo(&qmax, sizeof(int)), + PushConstantDataInfo(&eps_val, sizeof(float))}; + + std::string kernel_name("choose_qparams_block_wise"); + add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); + add_dtype_suffix(kernel_name, graph.dtype_of(input)); + + vkapi::ParamsBindList param_ubos; + + if (graph.is_buffer_storage(input)) { + param_ubos = { + graph.sizes_ubo(input), + graph.strides_ubo(input), + graph.sizes_ubo(scale_out), + graph.strides_ubo(scale_out), + graph.sizes_ubo(zp_out), + graph.strides_ubo(zp_out)}; + } else { + // For texture input, the shader uses buffer storage for outputs + // so we need buffer UBOs for the output tensors + param_ubos = { + graph.logical_limits_ubo(input), + graph.sizes_ubo(scale_out), + graph.strides_ubo(scale_out), + graph.sizes_ubo(zp_out), + graph.strides_ubo(zp_out)}; + } + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + choose_qparams_block_wise_pick_global_wg_size, + choose_qparams_block_wise_pick_local_wg_size, + // Inputs and Outputs + {{scale_out, vkapi::kWrite}, + {zp_out, vkapi::kWrite}, + {input, vkapi::kRead}}, + // Shader param buffers + param_ubos, + // Push Constants + push_constants, + // Specialization Constants + {}, + // Resize Args + {block_size}, + // Resizing Logic + nullptr)); } void choose_qparams_tensor_impl( @@ -278,9 +384,8 @@ void choose_qparams_tensor_impl( const ValueRef input = args[arg_idx++]; const ValueRef quant_min = args[arg_idx++]; const ValueRef quant_max = args[arg_idx++]; - const ValueRef eps = args[arg_idx++]; // Added eps parameter (will be voided) - const ValueRef dtype = - args[arg_idx++]; // Added dtype parameter (will be voided) + const ValueRef eps = args[arg_idx++]; + const ValueRef dtype = args[arg_idx++]; const 
ValueRef out_tuple_ref = args[arg_idx++]; ValueRef scale_out = kDummyValueRef; @@ -301,17 +406,11 @@ void choose_qparams_tensor_impl( VK_CHECK_COND(graph.val_is_tensor(zero_point_out)); // Verify input is a floating point type - VK_CHECK_COND( - graph.dtype_of(input) == vkapi::kFloat || - graph.dtype_of(input) == vkapi::kHalf || - graph.dtype_of(input) == vkapi::kDouble); + VK_CHECK_COND(graph.dtype_of(input) == vkapi::kFloat); - // Verify output types - accept both int32 and float32 for zero_point - // TorchAO may use float32 for zero_point in some cases + // Verify output types VK_CHECK_COND(graph.dtype_of(scale_out) == vkapi::kFloat); - VK_CHECK_COND( - graph.dtype_of(zero_point_out) == vkapi::kInt || - graph.dtype_of(zero_point_out) == vkapi::kFloat); + VK_CHECK_COND(graph.dtype_of(zero_point_out) == vkapi::kInt); // Check that texture storage is width packed if (!graph.is_buffer_storage(input)) { @@ -327,8 +426,7 @@ void choose_qparams_per_token_asymmetric_impl( const std::vector& args) { int arg_idx = 0; const ValueRef input = args[arg_idx++]; - const ValueRef dtype = - args[arg_idx++]; // Added dtype parameter (will be voided) + const ValueRef dtype = args[arg_idx++]; const ValueRef out_tuple_ref = args[arg_idx++]; ValueRef scale_out = kDummyValueRef; @@ -349,17 +447,16 @@ void choose_qparams_per_token_asymmetric_impl( VK_CHECK_COND(graph.val_is_tensor(zero_point_out)); // Verify input is a floating point type - VK_CHECK_COND( - graph.dtype_of(input) == vkapi::kFloat || - graph.dtype_of(input) == vkapi::kHalf || - graph.dtype_of(input) == vkapi::kDouble); + VK_CHECK_COND(graph.dtype_of(input) == vkapi::kFloat); - // Verify output types - accept both int32 and float32 for zero_point - // TorchAO may use float32 for zero_point in some cases + // Verify output types VK_CHECK_COND(graph.dtype_of(scale_out) == vkapi::kFloat); - VK_CHECK_COND( - graph.dtype_of(zero_point_out) == vkapi::kInt || - graph.dtype_of(zero_point_out) == vkapi::kFloat); + VK_CHECK_COND(graph.dtype_of(zero_point_out) == vkapi::kInt); + + // Check that texture storage is width packed + if (!graph.is_buffer_storage(input)) { + VK_CHECK_COND(graph.packed_dim_of(input) == WHCN::kWidthDim); + } add_choose_qparams_per_token_asymmetric_node( graph, input, scale_out, zero_point_out); @@ -370,9 +467,8 @@ void choose_qparams_affine_impl( const std::vector& args) { int arg_idx = 0; const ValueRef input = args[arg_idx++]; - const ValueRef mapping_type = args[arg_idx++]; // str - ignored for per-tensor - const ValueRef block_size = - args[arg_idx++]; // SymInt[] - ignored for per-tensor + const ValueRef mapping_type = args[arg_idx++]; + const ValueRef block_size = args[arg_idx++]; const ValueRef target_dtype = args[arg_idx++]; const ValueRef quant_min = args[arg_idx++]; const ValueRef quant_max = args[arg_idx++]; @@ -382,7 +478,6 @@ void choose_qparams_affine_impl( const ValueRef out_tuple_ref = args[arg_idx++]; // Suppress unused variable warnings - (void)mapping_type; (void)target_dtype; (void)scale_dtype; (void)zero_point_dtype; @@ -402,36 +497,42 @@ void choose_qparams_affine_impl( VK_CHECK_COND(graph.val_is_tensor(zero_point_out)); // Verify input is a floating point type - VK_CHECK_COND( - graph.dtype_of(input) == vkapi::kFloat || - graph.dtype_of(input) == vkapi::kHalf || - graph.dtype_of(input) == vkapi::kDouble); + VK_CHECK_COND(graph.dtype_of(input) == vkapi::kFloat); - // Verify output types - accept both int32 and float32 for zero_point - // TorchAO may use float32 for zero_point in some cases + // Verify output 
types VK_CHECK_COND(graph.dtype_of(scale_out) == vkapi::kFloat); - VK_CHECK_COND( - graph.dtype_of(zero_point_out) == vkapi::kInt || - graph.dtype_of(zero_point_out) == vkapi::kFloat); + VK_CHECK_COND(graph.dtype_of(zero_point_out) == vkapi::kInt); + + // Check that texture storage is width packed + if (!graph.is_buffer_storage(input)) { + VK_CHECK_COND(graph.packed_dim_of(input) == WHCN::kWidthDim); + } - // Check if this is per-tensor quantization (only supported granularity) - // block_size should equal input tensor dimensions for per-tensor quantization const auto input_sizes = graph.sizes_of(input); const auto block_size_list = graph.get_int_list(block_size); VK_CHECK_COND(block_size_list->size() == input_sizes.size()); - for (size_t i = 0; i < input_sizes.size(); i++) { - VK_CHECK_COND((*block_size_list)[i] == input_sizes[i]); - } - // Check that texture storage is width packed - if (!graph.is_buffer_storage(input)) { - VK_CHECK_COND(graph.packed_dim_of(input) == WHCN::kWidthDim); + std::string mapping_type_str = graph.get_string(mapping_type); + int mapping_type_val = 0; // Default to ASYMMETRIC + + if (mapping_type_str == "ASYMMETRIC") { + mapping_type_val = 0; + } else if (mapping_type_str == "SYMMETRIC") { + mapping_type_val = 1; + } else if (mapping_type_str == "SYMMETRIC_NO_CLIPPING_ERR") { + mapping_type_val = 2; } - // Default to per-tensor quantization parameter calculation for TorchAO affine - // ops - add_choose_qparams_tensor_node( - graph, input, quant_min, quant_max, eps, scale_out, zero_point_out); + add_choose_qparams_block_wise_node( + graph, + input, + block_size, + mapping_type_val, + quant_min, + quant_max, + eps, + scale_out, + zero_point_out); } REGISTER_OPERATORS { diff --git a/backends/vulkan/test/op_tests/quantize_affine_test.cpp b/backends/vulkan/test/op_tests/quantize_affine_test.cpp index 8a54774d703..d2a971da82b 100644 --- a/backends/vulkan/test/op_tests/quantize_affine_test.cpp +++ b/backends/vulkan/test/op_tests/quantize_affine_test.cpp @@ -279,6 +279,134 @@ at::Tensor dequantize_affine_reference_impl( std::string("INT")); } +std::tuple choose_qparams_affine_reference_impl( + const at::Tensor& input_, + const std::string& mapping_type, + const std::vector& block_size, + int64_t quant_min, + int64_t quant_max, + double eps) { + const int64_t ndim = input_.dim(); + _check_dims("input", block_size.size(), ndim); + + VK_CHECK_COND( + input_.scalar_type() == at::kFloat || input_.scalar_type() == at::kHalf || + input_.scalar_type() == at::kBFloat16, + "Unsupported input dtype: ", + input_.dtype()); + + at::Tensor input = input_.contiguous(); + + std::vector shape_for_reduction; + std::vector reduction_dims; + int64_t cur_dim = 0; + + auto in_sizes = input.sizes(); + for (int64_t i = 0; i < ndim; ++i) { + const int64_t blk = block_size[i]; + const int64_t dim = in_sizes[i]; + + if (blk != dim && blk > 1) { + VK_CHECK_COND( + dim % blk == 0, + "Input size ", + dim, + " is not divisible by block_size ", + blk, + " at dimension ", + i); + shape_for_reduction.push_back(dim / blk); + shape_for_reduction.push_back(blk); + reduction_dims.push_back(cur_dim + 1); + cur_dim += 2; + } else { + shape_for_reduction.push_back(dim); + if (blk != 1) { + reduction_dims.push_back(cur_dim); + } + cur_dim += 1; + } + } + + at::Tensor input_reshaped = input.view(shape_for_reduction); + + std::vector shape_after_reduction = shape_for_reduction; + for (int64_t d : reduction_dims) { + shape_after_reduction[d] = 1; + } + + at::Tensor min_val = input_reshaped.amin(reduction_dims, 
/*keepdim=*/true); + at::Tensor max_val = input_reshaped.amax(reduction_dims, /*keepdim=*/true); + + at::Tensor scale, zero_point; + + if (mapping_type == "ASYMMETRIC") { + // Include zero in the range + min_val = at::minimum(min_val, at::zeros_like(min_val)); + max_val = at::maximum(max_val, at::zeros_like(max_val)); + + // Calculate scale + scale = (max_val - min_val) / (quant_max - quant_min); + scale = at::maximum(scale, at::full_like(scale, eps)); + + // Calculate zero_point + zero_point = at::round(quant_min - min_val / scale); + zero_point = at::clamp(zero_point, quant_min, quant_max); + } else if (mapping_type == "SYMMETRIC") { + // Include zero in the range + min_val = at::minimum(min_val, at::zeros_like(min_val)); + max_val = at::maximum(max_val, at::zeros_like(max_val)); + + // Calculate max absolute value + at::Tensor abs_min = at::abs(min_val); + at::Tensor abs_max = at::abs(max_val); + at::Tensor M = at::maximum(abs_min, abs_max); + + // Calculate scale + scale = M / ((quant_max - quant_min) * 0.5); + scale = at::maximum(scale, at::full_like(scale, eps)); + + // Calculate zero_point (mid-point) + zero_point = + at::full_like(scale, (quant_max + quant_min + 1) / 2, at::kInt); + } else if (mapping_type == "SYMMETRIC_NO_CLIPPING_ERR") { + // Include zero in the range + min_val = at::minimum(min_val, at::zeros_like(min_val)); + max_val = at::maximum(max_val, at::zeros_like(max_val)); + + // Calculate scale based on min/max values + at::Tensor s_min = at::abs(min_val) / std::abs(quant_min); + at::Tensor s_max = max_val / quant_max; + scale = at::maximum(s_min, s_max); + scale = at::maximum(scale, at::full_like(scale, eps)); + + // Calculate zero_point (mid-point) + zero_point = + at::full_like(scale, (quant_max + quant_min + 1) / 2, at::kInt); + } else { + VK_CHECK_COND( + false, + "Unsupported mapping_type: ", + mapping_type, + ". 
Expected ASYMMETRIC, SYMMETRIC, or SYMMETRIC_NO_CLIPPING_ERR"); + } + + std::vector output_shape; + for (size_t i = 0; i < shape_after_reduction.size(); ++i) { + if (shape_after_reduction[i] != 1 || + std::find(reduction_dims.begin(), reduction_dims.end(), i) == + reduction_dims.end()) { + output_shape.push_back(shape_after_reduction[i]); + } + } + + // Reshape scale and zero_point to final output shape + scale = scale.view(output_shape); + zero_point = zero_point.view(output_shape); + + return std::make_tuple(scale, zero_point); +} + void test_vulkan_quantize_affine_impl( const std::vector& input_sizes, const std::vector& block_size, @@ -857,3 +985,395 @@ TEST(VulkanDequantizeAffineTest, test_4d_dequantization) { at::kChar, // input dtype at::kFloat); // output dtype } + +void test_vulkan_choose_qparams_affine_impl( + const std::vector& input_sizes, + const std::vector& block_size, + const std::string& mapping_type, + int64_t quant_min, + int64_t quant_max, + double eps, + at::ScalarType in_dtype = at::kFloat, + const vkcompute::utils::StorageType in_storage = + vkcompute::utils::kTexture3D, + const vkcompute::utils::StorageType out_storage = + vkcompute::utils::kBuffer) { + // Create input tensor with random values + std::vector input_sizes_int64( + input_sizes.begin(), input_sizes.end()); + at::Tensor input = + at::rand(input_sizes_int64, at::device(at::kCPU).dtype(in_dtype)); + + // Get reference output + auto reference_out = choose_qparams_affine_reference_impl( + input, mapping_type, block_size, quant_min, quant_max, eps); + + at::Tensor reference_scale = std::get<0>(reference_out); + at::Tensor reference_zero_point = std::get<1>(reference_out); + + reference_zero_point = reference_zero_point.to(at::kInt); + + using namespace vkcompute; + + GraphConfig config; + config.set_storage_type_override(in_storage); + ComputeGraph graph(config); + + IOValueRef r_input = graph.add_input_tensor( + input.sizes().vec(), from_at_scalartype(input.scalar_type()), in_storage); + + // Create mapping_type as string + std::string mapping_type_copy = mapping_type; + const ValueRef r_mapping_type = + graph.add_string(std::move(mapping_type_copy)); + + // Create block_size as IntList + std::vector block_size_copy(block_size); + const ValueRef r_block_size = + graph.add_scalar_list(std::move(block_size_copy)); + + // Create target_dtype, quant_min, quant_max, eps + const ValueRef r_target_dtype = + graph.add_scalar(static_cast(at::kChar)); + const ValueRef r_quant_min = graph.add_scalar(quant_min); + const ValueRef r_quant_max = graph.add_scalar(quant_max); + const ValueRef r_eps = graph.add_scalar(eps); + + // Create scale_dtype and zero_point_dtype + const ValueRef r_scale_dtype = + graph.add_scalar(static_cast(at::kFloat)); + const ValueRef r_zero_point_dtype = + graph.add_scalar(static_cast(at::kInt)); + + // Create output tuple + std::vector out_tuple; + + // Create scale and zero_point output tensors + const ValueRef r_scale_out = graph.add_tensor( + reference_scale.sizes().vec(), vkapi::kFloat, out_storage); + const ValueRef r_zero_point_out = graph.add_tensor( + reference_zero_point.sizes().vec(), vkapi::kInt, out_storage); + + out_tuple.push_back(r_scale_out); + out_tuple.push_back(r_zero_point_out); + + const ValueRef r_out_tuple = graph.add_value_list(std::move(out_tuple)); + + VK_GET_OP_FN("torchao.choose_qparams_affine.default") + (graph, + { + r_input.value, + r_mapping_type, + r_block_size, + r_target_dtype, + r_quant_min, + r_quant_max, + r_eps, + r_scale_dtype, + r_zero_point_dtype, + 
+       r_out_tuple,
+   });
+
+  ValueRef staging_scale = graph.set_output_tensor(r_scale_out);
+  ValueRef staging_zero_point = graph.set_output_tensor(r_zero_point_out);
+
+  graph.prepare();
+  graph.prepack();
+  graph.encode_execute();
+
+  // Copy input data to GPU
+  graph.copy_into_staging(
+      r_input.staging, input.const_data_ptr(), input.numel());
+
+  // Execute the graph
+  graph.execute();
+
+  // Copy output data back to CPU
+  at::Tensor vk_scale = at::empty_like(reference_scale).contiguous();
+  at::Tensor vk_zero_point = at::empty_like(reference_zero_point).contiguous();
+
+  graph.copy_from_staging(
+      staging_scale, vk_scale.mutable_data_ptr(), vk_scale.numel());
+  graph.copy_from_staging(
+      staging_zero_point,
+      vk_zero_point.mutable_data_ptr(),
+      vk_zero_point.numel());
+
+  // Compare outputs
+  const bool scale_correct =
+      at::allclose(reference_scale, vk_scale, /*rtol=*/1e-3, /*atol=*/1e-3);
+
+  // For zero point, we need to compare as integers since zero point should be
+  // an integer. First convert both tensors to int if they aren't already
+  at::Tensor ref_zp_int = reference_zero_point.to(at::kInt);
+  at::Tensor vk_zp_int = vk_zero_point.to(at::kInt);
+  const bool zero_point_correct = at::equal(ref_zp_int, vk_zp_int);
+
+  if (!scale_correct || !zero_point_correct) {
+    std::cout << "\nFailed with parameters:" << std::endl;
+    std::cout << "  input_sizes: [";
+    for (size_t i = 0; i < input_sizes.size(); i++) {
+      std::cout << input_sizes[i] << (i < input_sizes.size() - 1 ? ", " : "");
+    }
+    std::cout << "]" << std::endl;
+    std::cout << "  block_size: [";
+    for (size_t i = 0; i < block_size.size(); i++) {
+      std::cout << block_size[i] << (i < block_size.size() - 1 ? ", " : "");
+    }
+    std::cout << "]" << std::endl;
+    std::cout << "  mapping_type: " << mapping_type << std::endl;
+    std::cout << "  quant_min: " << quant_min << std::endl;
+    std::cout << "  quant_max: " << quant_max << std::endl;
+    std::cout << "  eps: " << eps << std::endl;
+    std::cout << "  storage type: "
+              << (in_storage == vkcompute::utils::kBuffer ? "buffer"
+                                                          : "texture")
+              << std::endl;
+
+    if (!scale_correct || !zero_point_correct) {
+      std::cout << "input:" << std::endl;
+      std::cout << input << std::endl;
+
+      std::cout << "reference_scale:" << std::endl
+                << reference_scale << std::endl;
+      std::cout << "vulkan_scale:" << std::endl << vk_scale << std::endl;
+
+      std::cout << "reference_zero_point:" << std::endl
+                << reference_zero_point << std::endl;
+      std::cout << "vulkan_zero_point:" << std::endl
+                << vk_zero_point << std::endl;
+    }
+  }
+
+  ASSERT_TRUE(scale_correct);
+  ASSERT_TRUE(zero_point_correct);
+}
+
+// Wrapper function to test both buffer and texture storage types
+void test_vulkan_choose_qparams_affine(
+    const std::vector<int>& input_sizes,
+    const std::vector<int64_t>& block_size,
+    const std::string& mapping_type,
+    int64_t quant_min,
+    int64_t quant_max,
+    double eps,
+    at::ScalarType in_dtype = at::kFloat) {
+  // Test with buffer storage for both input and output
+  test_vulkan_choose_qparams_affine_impl(
+      input_sizes,
+      block_size,
+      mapping_type,
+      quant_min,
+      quant_max,
+      eps,
+      in_dtype,
+      vkcompute::utils::kBuffer,
+      vkcompute::utils::kBuffer);
+
+  // Test with texture storage for input and buffer storage for output
+  // (shader always uses buffer storage for outputs)
+  test_vulkan_choose_qparams_affine_impl(
+      input_sizes,
+      block_size,
+      mapping_type,
+      quant_min,
+      quant_max,
+      eps,
+      in_dtype,
+      vkcompute::utils::kTexture3D,
+      vkcompute::utils::kBuffer);
+}
+
+TEST(VulkanChooseQParamsAffineTest, test_1d_asymmetric) {
+  // 1D: 12 Tensor, block_size is 3
+  test_vulkan_choose_qparams_affine(
+      {12}, // input_sizes
+      {3}, // block_size
+      "ASYMMETRIC", // mapping_type
+      -128, // quant_min (char min)
+      127, // quant_max (char max)
+      1e-5, // eps
+      at::kFloat); // input dtype
+}
+
+TEST(VulkanChooseQParamsAffineTest, test_2d_symmetric) {
+  // 2D: 8x6 Tensor, block_size is 2x3
+  test_vulkan_choose_qparams_affine(
+      {8, 6}, // input_sizes
+      {2, 3}, // block_size
+      "SYMMETRIC", // mapping_type
+      -128, // quant_min (char min)
+      127, // quant_max (char max)
+      1e-5, // eps
+      at::kFloat); // input dtype
+}
+
+TEST(VulkanChooseQParamsAffineTest, test_3d_symmetric_no_clipping) {
+  // 3D: 6x4x6 Tensor, block_size is 3x2x2
+  test_vulkan_choose_qparams_affine(
+      {6, 4, 6}, // input_sizes
+      {3, 2, 2}, // block_size
+      "SYMMETRIC_NO_CLIPPING_ERR", // mapping_type
+      -128, // quant_min (char min)
+      127, // quant_max (char max)
+      1e-5, // eps
+      at::kFloat); // input dtype
+}
+
+TEST(VulkanChooseQParamsAffineTest, test_4d_asymmetric) {
+  // 4D: 4x6x6x6 Tensor, block_size is 2x3x2x3
+  test_vulkan_choose_qparams_affine(
+      {4, 6, 6, 6}, // input_sizes (reduced from 8 to 4 to make test faster)
+      {2, 3, 2, 3}, // block_size
+      "ASYMMETRIC", // mapping_type
+      -128, // quant_min (char min)
+      127, // quant_max (char max)
+      1e-5, // eps
+      at::kFloat); // input dtype
+}
+
+TEST(VulkanChooseQParamsAffineTest, test_per_tensor) {
+  // Per-tensor: block_size equals tensor size
+  test_vulkan_choose_qparams_affine(
+      {4, 6, 8}, // input_sizes
+      {4, 6, 8}, // block_size equals tensor size
+      "ASYMMETRIC", // mapping_type
+      -128, // quant_min (char min)
+      127, // quant_max (char max)
+      1e-5, // eps
+      at::kFloat); // input dtype
+}
+
+TEST(VulkanChooseQParamsAffineTest, test_per_token) {
+  // Per-token: block_size is all 1s except last dimension
+  test_vulkan_choose_qparams_affine(
+      {4, 6, 8}, // input_sizes
+      {1, 1, 8}, // block_size is all 1s except last dimension
+      "ASYMMETRIC", // mapping_type
+      -128, // quant_min (char min)
+      127, // quant_max (char max)
+      1e-5, // eps
+      at::kFloat); // input dtype
+}
+
+// Additional tests for choose_qparams_affine
+
+TEST(VulkanChooseQParamsAffineTest, test_uint8_range) {
+  // Test with uint8 range (0-255)
+  test_vulkan_choose_qparams_affine(
+      {6, 8}, // input_sizes
+      {2, 4}, // block_size
+      "ASYMMETRIC", // mapping_type
+      0, // quant_min (uint8 min)
+      255, // quant_max (uint8 max)
+      1e-5, // eps
+      at::kFloat); // input dtype
+}
+
+TEST(VulkanChooseQParamsAffineTest, test_int16_range) {
+  // Test with int16 range (-32768 to 32767)
+  test_vulkan_choose_qparams_affine(
+      {6, 8}, // input_sizes
+      {2, 4}, // block_size
+      "SYMMETRIC", // mapping_type
+      -32768, // quant_min (int16 min)
+      32767, // quant_max (int16 max)
+      1e-5, // eps
+      at::kFloat); // input dtype
+}
+
+TEST(VulkanChooseQParamsAffineTest, test_larger_eps) {
+  // Test with larger epsilon value
+  test_vulkan_choose_qparams_affine(
+      {6, 8}, // input_sizes
+      {2, 4}, // block_size
+      "ASYMMETRIC", // mapping_type
+      -128, // quant_min
+      127, // quant_max
+      1e-2, // larger eps
+      at::kFloat); // input dtype
+}
+
+TEST(VulkanChooseQParamsAffineTest, test_per_channel_first_dim) {
+  // Per-channel quantization on first dimension
+  test_vulkan_choose_qparams_affine(
+      {8, 6, 4}, // input_sizes
+      {1, 6, 4}, // block_size (per-channel on dim 0)
+      "SYMMETRIC", // mapping_type
+      -128, // quant_min
+      127, // quant_max
+      1e-5, // eps
+      at::kFloat); // input dtype
+}
+
+TEST(VulkanChooseQParamsAffineTest, test_per_channel_middle_dim) {
+  // Per-channel quantization on middle dimension
+  test_vulkan_choose_qparams_affine(
+      {4, 8, 6}, // input_sizes
+      {4, 1, 6}, // block_size (per-channel on dim 1)
+      "SYMMETRIC", // mapping_type
+      -128, // quant_min
+      127, // quant_max
+      1e-5, // eps
+      at::kFloat); // input dtype
+}
+
+TEST(VulkanChooseQParamsAffineTest, test_mixed_block_sizes) {
+  // Mixed block sizes (some dimensions fully quantized, some partially)
+  test_vulkan_choose_qparams_affine(
+      {8, 6, 10}, // input_sizes
+      {4, 6, 2}, // block_size (mixed: partial, full, partial)
+      "ASYMMETRIC", // mapping_type
+      -128, // quant_min
+      127, // quant_max
+      1e-5, // eps
+      at::kFloat); // input dtype
+}
+
+TEST(VulkanChooseQParamsAffineTest, test_small_tensor) {
+  // Test with a small tensor
+  test_vulkan_choose_qparams_affine(
+      {2, 3}, // small input_sizes
+      {2, 3}, // block_size (full tensor)
+      "ASYMMETRIC", // mapping_type
+      -128, // quant_min
+      127, // quant_max
+      1e-5, // eps
+      at::kFloat); // input dtype
+}
+
+TEST(VulkanChooseQParamsAffineTest, test_asymmetric_narrow_range) {
+  // Test with a narrow quantization range
+  test_vulkan_choose_qparams_affine(
+      {6, 8}, // input_sizes
+      {2, 4}, // block_size
+      "ASYMMETRIC", // mapping_type
+      -10, // quant_min (narrow range)
+      10, // quant_max (narrow range)
+      1e-5, // eps
+      at::kFloat); // input dtype
+}
+
+TEST(VulkanChooseQParamsAffineTest, test_symmetric_narrow_range) {
+  // Test with a narrow quantization range with symmetric mapping
+  test_vulkan_choose_qparams_affine(
+      {6, 8}, // input_sizes
+      {2, 4}, // block_size
+      "SYMMETRIC", // mapping_type
+      -10, // quant_min (narrow range)
+      10, // quant_max (narrow range)
+      1e-5, // eps
+      at::kFloat); // input dtype
+}
+
+TEST(VulkanChooseQParamsAffineTest, test_symmetric_no_clipping_narrow_range) {
+  // Test with a narrow quantization range with symmetric no clipping mapping
+  test_vulkan_choose_qparams_affine(
+      {6, 8}, // input_sizes
+      {2, 4}, // block_size
+      "SYMMETRIC_NO_CLIPPING_ERR", // mapping_type
+      -10, // quant_min (narrow range)
+      10, // quant_max (narrow range)
+      1e-5, // eps
+      at::kFloat); // input dtype
+}
\ No newline at end of file
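
For reference, the per-block math that the tensor-wise reference implementation above applies for the three mapping types can be sketched as a standalone scalar helper. This is a minimal illustration only, not part of the diff; the file name, function name, and example values are invented for demonstration, and it assumes quant_min < 0 < quant_max as in the tests above.

// choose_qparams_block_sketch.cpp -- illustrative scalar version of the
// reference math above (compile with -std=c++17).
#include <algorithm>
#include <cmath>
#include <iostream>
#include <stdexcept>
#include <string>
#include <utility>

// Computes (scale, zero_point) for one min/max block, mirroring the
// ASYMMETRIC / SYMMETRIC / SYMMETRIC_NO_CLIPPING_ERR branches of the
// tensor-wise reference implementation in the test.
std::pair<double, int> choose_qparams_block(
    double min_val,
    double max_val,
    const std::string& mapping_type,
    int quant_min,
    int quant_max,
    double eps) {
  // Always include zero in the representable range.
  min_val = std::min(min_val, 0.0);
  max_val = std::max(max_val, 0.0);

  double scale = 0.0;
  int zero_point = 0;

  if (mapping_type == "ASYMMETRIC") {
    scale = (max_val - min_val) / static_cast<double>(quant_max - quant_min);
    scale = std::max(scale, eps);
    double zp = std::round(quant_min - min_val / scale);
    zero_point = static_cast<int>(std::min(
        std::max(zp, static_cast<double>(quant_min)),
        static_cast<double>(quant_max)));
  } else if (mapping_type == "SYMMETRIC") {
    double M = std::max(std::abs(min_val), std::abs(max_val));
    scale = std::max(M / ((quant_max - quant_min) * 0.5), eps);
    zero_point = (quant_max + quant_min + 1) / 2; // mid-point
  } else if (mapping_type == "SYMMETRIC_NO_CLIPPING_ERR") {
    // Assumes quant_min < 0 < quant_max so neither divisor is zero.
    double s_min = std::abs(min_val) / std::abs(static_cast<double>(quant_min));
    double s_max = max_val / static_cast<double>(quant_max);
    scale = std::max(std::max(s_min, s_max), eps);
    zero_point = (quant_max + quant_min + 1) / 2; // mid-point
  } else {
    throw std::invalid_argument("Unsupported mapping_type: " + mapping_type);
  }
  return {scale, zero_point};
}

int main() {
  // Example block statistics; int8 range and eps match most tests above.
  auto [scale, zp] = choose_qparams_block(-1.0, 2.0, "ASYMMETRIC", -128, 127, 1e-5);
  std::cout << "scale=" << scale << " zero_point=" << zp << std::endl;
  // Expected: scale = 3.0 / 255 (about 0.01176), zero_point = -43.
  return 0;
}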