diff --git a/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.glsl index 94072dfbfea..43e62eadeee 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.glsl @@ -53,6 +53,17 @@ $if MODE == "per_channel": int quant_min; int quant_max; }; +$if MODE == "block_wise": + ${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")} + ${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")} + + layout(push_constant) uniform restrict Block { + ivec4 blockSize; // bW, bH, bC, bN + ivec4 numBlocks; // tW/bW, tH/bH, tC/bC, tN/bN + ivec4 blockStride; // pre-computed linear strides for the block grid + int quant_min; + int quant_max; + }; ${layout_declare_ubo(B, "int", "out_numel")} ${layout_declare_ubo(B, "ivec4", "t_in_sizes")} @@ -71,68 +82,60 @@ const lowp ivec4 out_dim_order = unhash_dim_order(out_layout); const lowp ivec4 in_dim_order = unhash_dim_order(in_layout); /* - * DEQUANTIZATION SHADER (BUFFER STORAGE) - * - * This shader converts n-bit integer tensor values back to floating-point representations - * using pre-computed quantization parameters (scale and zero_point). The dequantization - * reconstructs the original floating-point values from their discrete integer representations - * with minimal precision loss. - * - * ALGORITHM: - * 1. Load quantized integer value from buffer - * 2. Apply dequantization formula: value = (qvalue - zero_point) * scale - * 3. Store reconstructed floating-point value to output buffer - * - * WORKGROUP CONFIGURATION: - * - Per-Tensor Mode: - * - Global WG Size: {num_elements, 1, 1} (one thread per tensor element) - * - Local WG Size: Default (typically {64, 1, 1} or based on global WG size) - * - Per-Token Mode: - * - Global WG Size: {num_elements, 1, 1} (one thread per tensor element) - * - Local WG Size: Default (typically {64, 1, 1} or based on global WG size) - * - * SUPPORTED CONFIGURATIONS: - * - Buffer Storage: Uses linear buffer indexing with stride-based tensor access - * - Per-Tensor: Supports any tensor layout through stride calculations and dimension ordering - * - Per-Token: Supports only width packed tensors (packed_dim = 0) and standard axis mapping - * - Scale/zero_point tensors: Must use buffer storage with width packing (packed_dim = 0) - * - * DEQUANTIZATION FORMULA VISUALIZATION: - * For integer range [quant_min, quant_max] mapped back to [min_val, max_val]: - * - * Integer Domain: Floating Point Domain: - * quant_min ──────────────► min_val - * │ │ - * │ scale = (max_val - min_val) / (quant_max - quant_min) - * │ zero_point = quant_min - round(min_val / scale) - * │ │ - * quant_max ──────────────► max_val - * - * Dequantization Process: - * Input: -103 (int8) - * Step 1: qvalue - zero_point = -103 - (-128) = 25 - * Step 2: result * scale = 25 * 0.1 = 2.5 - * Output: 2.5 (float) - * - * PER-TENSOR DEQUANTIZATION: - * - Single scale and zero_point values for entire tensor - * - All elements use same dequantization parameters - * - Parameters passed as push constants for efficiency - * - Formula: value = (qvalue - zero_point) * scale - * - * PER-TOKEN DEQUANTIZATION: - * - Separate scale and zero_point for each token - * - Token = all elements except last dimension (e.g., for [B,S,H]: B*S tokens of H elements) - * - Parameters stored in buffer arrays indexed by token_id - * - Each thread calculates its token_id from tensor coordinates - * - Formula: value = (qvalue - 
zero_point[token_id]) * scale[token_id] - * - * Token ID calculation for element at tensor index (w, z, y, x): - * - 4D tensor: token_id = w * (sizes.z * sizes.y) + z * sizes.y + y - * - 3D tensor: token_id = z * sizes.y + y - * - 2D tensor: token_id = y - * - 1D tensor: token_id = 0 - */ + Dequantization Shader (Buffer Storage) + This shader converts n-bit integer tensor values back to floating-point representations + using pre-computed quantization parameters (scale and zero_point). The dequantization + reconstructs the original floating-point values from their discrete integer representations + with minimal precision loss. + + Important Considerations: + (+) All input tensors are assumed to be WIDTH_PACKED (i.e., contiguous in the last dimension) + (+) The axis map layout is assumed to be a standard layout for scales and zero_points + (++) The scale and zero_point tensors must be implemented as buffers + + Workgroup Configuration: + - dequantize_per_tensor + This mode reverses the uniform quantization applied across the entire tensor by using the + single scale and zero_point values to convert quantized integer values back to their original + floating-point representation. + + (*) global_wg_size: default + (*) local_wg_size: default + + - dequantize_per_token + This mode reverses the quantization applied individually to each token (or element) in the + input by using separate scale and zero_point values for each token. For a tensor of shape + [B, S, H], it applies the inverse transformation token-wise across the B*S tokens, converting + quantized values back to their original floating-point representation for each group of H + elements independently. + + (*) global_wg_size: default + (*) local_wg_size: default + + - dequantize_per_channel + This mode reverses the quantization applied separately to each channel of the input tensor + by using distinct scale and zero_point values for each channel. For a tensor of shape + [B, C, H, W] with axis = 1, it applies the inverse transformation channel-wise across the C + channels, converting quantized values back to their original floating-point representation + independently for each channel. + + (*) global_wg_size: default + (*) local_wg_size: default + + - dequantize_block_wise + This mode reverses the block-wise quantization applied to groups of elements by using separate + scale and zero_point values for each block. Equivalent to dequantize_affine, it applies the + inverse affine transformation per block to convert quantized values back to their original + floating-point representation. For example, if the tensor shape is [6, 9, 4] and + blockSize = [3, 3, 2], the tensor is divided into 12 blocks, each containing 18 elements, + and dequantization is performed independently on each block. 
+ + (*) global_wg_size: default + (*) local_wg_size: default + + Dequantization Formula: + value = (qvalue - zero_point) * scale +*/ #ifdef per_tensor @@ -187,7 +190,7 @@ void dequantize_per_token() { t_out[out_bufi] = value; } -#else // per_channel +#elif defined(per_channel) void dequantize_per_channel() { const int out_bufi = int(gl_GlobalInvocationID.x); @@ -226,6 +229,29 @@ void dequantize_per_channel() { t_out[out_bufi] = value; } +#else // block_wise + +void dequantize_block_wise() { + const int out_bufi = int(gl_GlobalInvocationID.x); + + if (out_bufi >= out_numel) { + return; + } + + const ivec4 out_tidx = bufi_to_tidx(out_bufi, t_out_strides, out_dim_order); + const int in_bufi = tidx_to_bufi(out_tidx, t_in_strides); + + IN_T qvalue = t_in[in_bufi]; + + const ivec4 bcoord = out_tidx / blockSize; + + const int block_id = bcoord.x * blockStride.x + bcoord.y * blockStride.y + bcoord.z * blockStride.z + bcoord.w * blockStride.w; + + const OUT_T value = dequantize_val(qvalue, t_scale[block_id], t_zero_point[block_id]); + + t_out[out_bufi] = value; +} + #endif void main() { diff --git a/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.yaml index b9a53217452..999c59d3b79 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.yaml @@ -19,3 +19,5 @@ dequantize_buffer: MODE: per_token - NAME: dequantize_per_channel_buffer MODE: per_channel + - NAME: dequantize_block_wise_buffer + MODE: block_wise diff --git a/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.glsl index 5c978c61846..20bf6c87e26 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.glsl @@ -56,6 +56,17 @@ $if MODE == "per_channel": int quant_min; int quant_max; }; +$if MODE == "block_wise": + ${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")} + ${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")} + + layout(push_constant) uniform restrict Block { + ivec4 blockSize; // bW, bH, bC, bN + ivec4 numBlocks; // tW/bW, tH/bH, tC/bC, tN/bN + ivec4 blockStride; // pre-computed linear strides for the block grid + int quant_min; + int quant_max; + }; ${layout_declare_ubo(B, "ivec3", "t_in_limits")} ${layout_declare_ubo(B, "ivec3", "t_out_limits")} @@ -201,7 +212,7 @@ void dequantize_per_token() { write_texel(t_out, pos, outtex); } -#else // per_channel +#elif defined(per_channel) void dequantize_per_channel() { const ivec3 pos = ivec3(gl_GlobalInvocationID); @@ -292,6 +303,39 @@ void dequantize_per_channel() { write_texel(t_out, pos, outtex); } +#else // block_wise + +void dequantize_block_wise() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (any(greaterThanEqual(pos, t_in_limits))) + return; + + IVEC4_T intex = load_texel(t_in, pos); + FVEC4_T outtex; + + ivec4 base_tidx = ivec4(pos.x * 4, pos.y, pos.z, 0); + int foldedZ = pos.z; + + int C_total = numBlocks.z * blockSize.z; + + [[unroll]] for (int i = 0; i < 4; ++i) { + ivec4 tidx = ivec4(base_tidx.x + i, base_tidx.y, (foldedZ % C_total), (foldedZ / C_total)); + + ivec4 bcoord = tidx / blockSize; + int block_id = bcoord.x * blockStride.x + bcoord.y * blockStride.y + bcoord.z * blockStride.z + bcoord.w * blockStride.w; + + IN_T qvalue = IN_T(intex[i]); + OUT_T value = dequantize_val(qvalue, t_scale[block_id], t_zero_point[block_id]); + $if 
OUT_DTYPE == "double": + outtex[i] = float(value); + $else: + outtex[i] = value; + } + + write_texel(t_out, pos, outtex); +} + #endif void main() { diff --git a/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.yaml b/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.yaml index 88ccc6e3274..9b624762192 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.yaml @@ -19,3 +19,5 @@ dequantize_texture: MODE: per_token - NAME: dequantize_per_channel_texture3d MODE: per_channel + - NAME: dequantize_block_wise_texture3d + MODE: block_wise diff --git a/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp b/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp index 7edb9b2f70d..61fd76145a4 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp @@ -17,38 +17,59 @@ namespace vkcompute { -void resize_dequantize_output( +void resize_dequantize_node( ComputeGraph* graph, const std::vector& args, const std::vector& extra_args) { (void)extra_args; - const ValueRef out = args.at(0).refs.at(0); - const ValueRef in = args.at(1).refs.at(0); - graph->virtual_resize(out, graph->sizes_of(in)); + + vTensorPtr out = graph->get_tensor(args[0].refs[0]); + vTensorPtr in = graph->get_tensor(args[1].refs[0]); + + out->virtual_resize(in->sizes()); } -utils::uvec3 dequantize_per_channel_global_wg_size( +utils::uvec3 dequantize_per_channel_local_wg_size( ComputeGraph* graph, const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, const std::vector& args, const std::vector& resize_args) { + (void)args; (void)resize_args; - const ValueRef out = args.at(0).refs.at(0); - utils::uvec3 global_wg_size = graph->create_global_wg_size(out); + const ValueRef input = args.at(1).refs.at(0); - return global_wg_size; + utils::uvec3 local_wg_size = + graph->create_local_wg_size(global_workgroup_size); + + // WORKAROUND: The CommandBuffer::dispatch function divides + // global_workgroup_size by local_workgroup_size to get the number of + // workgroups to dispatch. We need to ensure that we dispatch the correct + // number of workgroups in the Z dimension to cover all batch-channel + // combinations. + // + // If local_wg_size[2] > 1, then div_up(global_workgroup_size[2], + // local_wg_size[2]) might reduce the number of workgroups dispatched. To + // ensure we dispatch global_workgroup_size[2] workgroups in the Z dimension, + // we set local_wg_size[2] = 1. + const auto input_sizes = graph->sizes_of(input); + if (input_sizes.size() == 4 && !graph->is_buffer_storage(input) && + global_workgroup_size[2] > 1) { + local_wg_size[2] = 1; + } + + return local_wg_size; } -utils::uvec3 dequantize_per_channel_local_wg_size( +utils::uvec3 dequantize_block_wise_local_wg_size( ComputeGraph* graph, const vkapi::ShaderInfo& shader, const utils::uvec3& global_workgroup_size, const std::vector& args, const std::vector& resize_args) { - (void)args; + (void)shader; (void)resize_args; - const ValueRef input = args.at(1).refs.at(0); utils::uvec3 local_wg_size = @@ -56,16 +77,17 @@ utils::uvec3 dequantize_per_channel_local_wg_size( // WORKAROUND: The CommandBuffer::dispatch function divides // global_workgroup_size by local_workgroup_size to get the number of - // workgroups to dispatch. 
For per-channel dequantization along the batch - // axis, we need to ensure that we dispatch the correct number of workgroups - // in the Z dimension to cover all batch-channel combinations. + // workgroups to dispatch. We need to ensure that we dispatch the correct + // number of workgroups in the Z dimension to cover all batch-channel + // combinations. // // If local_wg_size[2] > 1, then div_up(global_workgroup_size[2], // local_wg_size[2]) might reduce the number of workgroups dispatched. To // ensure we dispatch global_workgroup_size[2] workgroups in the Z dimension, // we set local_wg_size[2] = 1. const auto input_sizes = graph->sizes_of(input); - if (global_workgroup_size[2] > 1 && input_sizes[3] > 0) { + if (input_sizes.size() == 4 && !graph->is_buffer_storage(input) && + global_workgroup_size[2] > 1) { local_wg_size[2] = 1; } @@ -131,7 +153,7 @@ void add_dequantize_per_tensor_node( // Resize Args {}, // Resizing Logic - resize_dequantize_output)); + resize_dequantize_node)); } void add_dequantize_per_token_node( @@ -161,25 +183,18 @@ void add_dequantize_per_token_node( graph.sizes_ubo(input), graph.strides_ubo(input), graph.sizes_ubo(output), - graph.strides_ubo(output), - }; - push_constants = { - PushConstantDataInfo(&num_tokens, sizeof(int)), - PushConstantDataInfo(&quant_min_val, sizeof(int)), - PushConstantDataInfo(&quant_max_val, sizeof(int)), - }; + graph.strides_ubo(output)}; } else { param_ubos = { - graph.logical_limits_ubo(input), - graph.logical_limits_ubo(output), - }; - push_constants = { - PushConstantDataInfo(&num_tokens, sizeof(int)), - PushConstantDataInfo(&quant_min_val, sizeof(int)), - PushConstantDataInfo(&quant_max_val, sizeof(int)), - }; + graph.logical_limits_ubo(input), graph.logical_limits_ubo(output)}; } + push_constants = { + PushConstantDataInfo(&num_tokens, sizeof(int)), + PushConstantDataInfo(&quant_min_val, sizeof(int)), + PushConstantDataInfo(&quant_max_val, sizeof(int)), + }; + vkapi::SpecVarList spec_vars = { graph.hashed_layout_of(output), graph.hashed_layout_of(input), @@ -203,7 +218,7 @@ void add_dequantize_per_token_node( // Resize Args {}, // Resizing Logic - resize_dequantize_output)); + resize_dequantize_node)); } void add_dequantize_per_channel_node( @@ -252,27 +267,19 @@ void add_dequantize_per_channel_node( graph.sizes_ubo(input), graph.strides_ubo(input), graph.sizes_ubo(output), - graph.strides_ubo(output), - }; - push_constants = { - PushConstantDataInfo(&axis_whcn, sizeof(int)), - PushConstantDataInfo(&num_channels, sizeof(int)), - PushConstantDataInfo(&quant_min_val, sizeof(int)), - PushConstantDataInfo(&quant_max_val, sizeof(int)), - }; + graph.strides_ubo(output)}; } else { param_ubos = { - graph.logical_limits_ubo(input), - graph.logical_limits_ubo(output), - }; - push_constants = { - PushConstantDataInfo(&axis_whcn, sizeof(int)), - PushConstantDataInfo(&num_channels, sizeof(int)), - PushConstantDataInfo(&quant_min_val, sizeof(int)), - PushConstantDataInfo(&quant_max_val, sizeof(int)), - }; + graph.logical_limits_ubo(input), graph.logical_limits_ubo(output)}; } + push_constants = { + PushConstantDataInfo(&axis_whcn, sizeof(int)), + PushConstantDataInfo(&num_channels, sizeof(int)), + PushConstantDataInfo(&quant_min_val, sizeof(int)), + PushConstantDataInfo(&quant_max_val, sizeof(int)), + }; + vkapi::SpecVarList spec_vars = { graph.hashed_layout_of(output), graph.hashed_layout_of(input), @@ -281,7 +288,7 @@ void add_dequantize_per_channel_node( graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, 
VK_KERNEL_FROM_STR(kernel_name), - dequantize_per_channel_global_wg_size, + default_pick_global_wg_size, dequantize_per_channel_local_wg_size, // Inputs and Outputs {{output, vkapi::kWrite}, @@ -296,7 +303,94 @@ void add_dequantize_per_channel_node( // Resize Args {}, // Resizing Logic - resize_dequantize_output)); + resize_dequantize_node)); +} + +void add_dequantize_block_wise_node( + ComputeGraph& graph, + const ValueRef& input, + const ValueRef& block_size, + const ValueRef& scale, + const ValueRef& zero_point, + const ValueRef& quant_min, + const ValueRef& quant_max, + const ValueRef& output) { + std::string kernel_name("dequantize_block_wise"); + add_storage_type_suffix(kernel_name, graph.storage_type_of(input)); + add_dtype_suffix(kernel_name, graph.dtype_of(input)); + add_dtype_suffix(kernel_name, graph.dtype_of(output)); + + int quant_min_val = static_cast(graph.get_int(quant_min)); + int quant_max_val = static_cast(graph.get_int(quant_max)); + + const auto input_sizes = graph.sizes_of(input); + const auto block_size_list = graph.get_int_list(block_size); + + // Convert dimensions to WHCN order for shader + utils::ivec4 block_size_vec = utils::make_whcn_ivec4(*block_size_list); + utils::ivec4 tensor_size_whcn = utils::make_whcn_ivec4(input_sizes); + + // Calculate numBlocks: tensorSize / blockSize (both in WHCN order) + utils::ivec4 num_blocks_vec = { + tensor_size_whcn[0] / block_size_vec[0], + tensor_size_whcn[1] / block_size_vec[1], + tensor_size_whcn[2] / block_size_vec[2], + tensor_size_whcn[3] / block_size_vec[3]}; + + // Calculate blockStride: pre-computed linear strides for the block grid + utils::ivec4 block_stride_vec = { + 1, + num_blocks_vec[0], + num_blocks_vec[0] * num_blocks_vec[1], + num_blocks_vec[0] * num_blocks_vec[1] * num_blocks_vec[2]}; + + vkapi::ParamsBindList param_ubos; + std::vector push_constants; + + if (graph.is_buffer_storage(input)) { + param_ubos = { + graph.numel_ubo(input), + graph.sizes_ubo(input), + graph.strides_ubo(input), + graph.sizes_ubo(output), + graph.strides_ubo(output)}; + } else { + param_ubos = { + graph.logical_limits_ubo(input), graph.logical_limits_ubo(output)}; + } + + push_constants = { + PushConstantDataInfo(&block_size_vec, sizeof(block_size_vec)), + PushConstantDataInfo(&num_blocks_vec, sizeof(num_blocks_vec)), + PushConstantDataInfo(&block_stride_vec, sizeof(block_stride_vec)), + PushConstantDataInfo(&quant_min_val, sizeof(int)), + PushConstantDataInfo(&quant_max_val, sizeof(int)), + }; + + vkapi::SpecVarList spec_vars = { + graph.hashed_layout_of(output), + graph.hashed_layout_of(input), + }; + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + default_pick_global_wg_size, + dequantize_block_wise_local_wg_size, + // Inputs and Outputs + {{output, vkapi::kWrite}, + {input, vkapi::kRead}, + {{scale, zero_point}, vkapi::kRead}}, + // Shader param buffers + param_ubos, + // Push Constants + push_constants, + // Specialization Constants + spec_vars, + // Resize Args + {}, + // Resizing Logic + resize_dequantize_node)); } void dequantize_per_tensor_impl( @@ -308,31 +402,39 @@ void dequantize_per_tensor_impl( const ValueRef zero_point = args[arg_idx++]; const ValueRef quant_min = args[arg_idx++]; const ValueRef quant_max = args[arg_idx++]; - const ValueRef dtype = args[arg_idx++]; // Added dtype parameter - const ValueRef output_dtype = args[arg_idx++]; // Added output_dtype parameter + const ValueRef dtype = args[arg_idx++]; + const ValueRef output_dtype = args[arg_idx++]; 
const ValueRef output = args[arg_idx++]; // Suppress unused variable warnings - dtype and output_dtype are inferred - // from output (void)dtype; (void)output_dtype; // Check tensor types VK_CHECK_COND(graph.val_is_tensor(input)); + VK_CHECK_COND(graph.val_is_tensor(scale)); + VK_CHECK_COND(graph.val_is_tensor(zero_point)); VK_CHECK_COND(graph.val_is_tensor(output)); // Verify input is an integer type VK_CHECK_COND( graph.dtype_of(input) == vkapi::kByte || graph.dtype_of(input) == vkapi::kChar || - graph.dtype_of(input) == vkapi::kShort || graph.dtype_of(input) == vkapi::kInt); - // Verify output is a floating point type - VK_CHECK_COND( - graph.dtype_of(output) == vkapi::kHalf || - graph.dtype_of(output) == vkapi::kFloat || - graph.dtype_of(output) == vkapi::kDouble); + // Check that scale and zero_point have buffer storage and width packing + VK_CHECK_COND(graph.is_buffer_storage(scale)); + VK_CHECK_COND(graph.packed_dim_of(scale) == WHCN::kWidthDim); + VK_CHECK_COND(graph.is_buffer_storage(zero_point)); + VK_CHECK_COND(graph.packed_dim_of(zero_point) == WHCN::kWidthDim); + + // Check that tensors with texture storage have standard axis map + if (!graph.is_buffer_storage(input)) { + VK_CHECK_COND(graph.has_standard_axis_map(input)); + } + if (!graph.is_buffer_storage(output)) { + VK_CHECK_COND(graph.has_standard_axis_map(output)); + } add_dequantize_per_tensor_node( graph, input, scale, zero_point, quant_min, quant_max, output); @@ -347,12 +449,11 @@ void dequantize_per_token_impl( const ValueRef zero_point = args[arg_idx++]; const ValueRef quant_min = args[arg_idx++]; const ValueRef quant_max = args[arg_idx++]; - const ValueRef dtype = args[arg_idx++]; // Added dtype parameter - const ValueRef output_dtype = args[arg_idx++]; // Added output_dtype parameter + const ValueRef dtype = args[arg_idx++]; + const ValueRef output_dtype = args[arg_idx++]; const ValueRef output = args[arg_idx++]; // Suppress unused variable warnings - dtype and output_dtype are inferred - // from output (void)dtype; (void)output_dtype; @@ -366,15 +467,8 @@ void dequantize_per_token_impl( VK_CHECK_COND( graph.dtype_of(input) == vkapi::kByte || graph.dtype_of(input) == vkapi::kChar || - graph.dtype_of(input) == vkapi::kShort || graph.dtype_of(input) == vkapi::kInt); - // Verify output is a floating point type - VK_CHECK_COND( - graph.dtype_of(output) == vkapi::kHalf || - graph.dtype_of(output) == vkapi::kFloat || - graph.dtype_of(output) == vkapi::kDouble); - // Check that scale and zero_point have buffer storage and width packing VK_CHECK_COND(graph.is_buffer_storage(scale)); VK_CHECK_COND(graph.packed_dim_of(scale) == WHCN::kWidthDim); @@ -430,12 +524,11 @@ void dequantize_per_channel_impl( const ValueRef axis = args[arg_idx++]; const ValueRef quant_min = args[arg_idx++]; const ValueRef quant_max = args[arg_idx++]; - const ValueRef dtype = args[arg_idx++]; // Added dtype parameter - const ValueRef output_dtype = args[arg_idx++]; // Added output_dtype parameter + const ValueRef dtype = args[arg_idx++]; + const ValueRef output_dtype = args[arg_idx++]; const ValueRef output = args[arg_idx++]; // Suppress unused variable warnings - dtype and output_dtype are inferred - // from output (void)dtype; (void)output_dtype; @@ -449,15 +542,8 @@ void dequantize_per_channel_impl( VK_CHECK_COND( graph.dtype_of(input) == vkapi::kByte || graph.dtype_of(input) == vkapi::kChar || - graph.dtype_of(input) == vkapi::kShort || graph.dtype_of(input) == vkapi::kInt); - // Verify output is a floating point type - VK_CHECK_COND( - 
graph.dtype_of(output) == vkapi::kHalf || - graph.dtype_of(output) == vkapi::kFloat || - graph.dtype_of(output) == vkapi::kDouble); - // Check that scale and zero_point have buffer storage and width packing VK_CHECK_COND(graph.is_buffer_storage(scale)); VK_CHECK_COND(graph.packed_dim_of(scale) == WHCN::kWidthDim); @@ -513,8 +599,7 @@ void dequantize_affine_impl( const std::vector& args) { int arg_idx = 0; const ValueRef input = args[arg_idx++]; - const ValueRef block_size = - args[arg_idx++]; // SymInt[] - ignored for per-tensor + const ValueRef block_size = args[arg_idx++]; const ValueRef scale = args[arg_idx++]; const ValueRef zero_point = args[arg_idx++]; const ValueRef input_dtype = args[arg_idx++]; @@ -529,33 +614,61 @@ void dequantize_affine_impl( // Check tensor types VK_CHECK_COND(graph.val_is_tensor(input)); + VK_CHECK_COND(graph.val_is_tensor(scale)); + VK_CHECK_COND(graph.val_is_tensor(zero_point)); VK_CHECK_COND(graph.val_is_tensor(output)); // Verify input is an integer type VK_CHECK_COND( graph.dtype_of(input) == vkapi::kByte || graph.dtype_of(input) == vkapi::kChar || - graph.dtype_of(input) == vkapi::kShort || graph.dtype_of(input) == vkapi::kInt); - // Verify output is a floating point type - VK_CHECK_COND( - graph.dtype_of(output) == vkapi::kHalf || - graph.dtype_of(output) == vkapi::kFloat || - graph.dtype_of(output) == vkapi::kDouble); + // Check that scale and zero_point have buffer storage and width packing + VK_CHECK_COND(graph.is_buffer_storage(scale)); + VK_CHECK_COND(graph.packed_dim_of(scale) == WHCN::kWidthDim); + VK_CHECK_COND(graph.is_buffer_storage(zero_point)); + VK_CHECK_COND(graph.packed_dim_of(zero_point) == WHCN::kWidthDim); + + // Check that tensors with texture storage have standard axis map + if (!graph.is_buffer_storage(input)) { + VK_CHECK_COND(graph.has_standard_axis_map(input)); + } + if (!graph.is_buffer_storage(output)) { + VK_CHECK_COND(graph.has_standard_axis_map(output)); + } - // Check if this is per-tensor quantization (only supported granularity) - // block_size should equal input tensor dimensions for per-tensor quantization + // Verify block_size is valid (each dimension must divide evenly into input + // size) const auto input_sizes = graph.sizes_of(input); const auto block_size_list = graph.get_int_list(block_size); VK_CHECK_COND(block_size_list->size() == input_sizes.size()); + for (size_t i = 0; i < input_sizes.size(); i++) { - VK_CHECK_COND((*block_size_list)[i] == input_sizes[i]); + if ((*block_size_list)[i] > 1) { + VK_CHECK_COND( + input_sizes[i] % (*block_size_list)[i] == 0, + "Input size at dimension ", + i, + " (", + input_sizes[i], + ") must be divisible by block_size at dimension ", + i, + " (", + (*block_size_list)[i], + ")"); + } } - // Default to per-tensor dequantization for TorchAO affine ops - add_dequantize_per_tensor_node( - graph, input, scale, zero_point, quant_min, quant_max, output); + add_dequantize_block_wise_node( + graph, + input, + block_size, + scale, + zero_point, + quant_min, + quant_max, + output); } REGISTER_OPERATORS { diff --git a/backends/vulkan/test/op_tests/quantize_affine_test.cpp b/backends/vulkan/test/op_tests/quantize_affine_test.cpp index cb782a92ba4..8a54774d703 100644 --- a/backends/vulkan/test/op_tests/quantize_affine_test.cpp +++ b/backends/vulkan/test/op_tests/quantize_affine_test.cpp @@ -134,6 +134,110 @@ at::Tensor quantize_affine_reference_impl( return q; } +at::Tensor dequantize_affine_reference_impl( + const at::Tensor& input_, + const std::vector& block_size, + const at::Tensor& 
scale, + const c10::optional& zero_point_opt, + int64_t quant_min, + int64_t quant_max, + at::ScalarType out_dtype, + c10::optional zero_point_domain_opt = std::string("INT")) { + const int64_t ndim = input_.dim(); + _check_dims("input", block_size.size(), ndim); + + VK_CHECK_COND( + input_.scalar_type() == at::kByte || input_.scalar_type() == at::kChar || + input_.scalar_type() == at::kShort || + input_.scalar_type() == at::kInt, + "Unsupported input dtype: ", + input_.dtype()); + + VK_CHECK_COND( + out_dtype == at::kFloat || out_dtype == at::kHalf || + out_dtype == at::kBFloat16, + "Unsupported output dtype: ", + out_dtype); + + auto zero_point_domain = + zero_point_domain_opt.has_value() ? *zero_point_domain_opt : "INT"; + + bool has_zp = zero_point_opt.has_value(); + VK_CHECK_COND( + has_zp || zero_point_domain == "NONE" || zero_point_domain == "", + "zero_point must be supplied unless zero_point_domain is NONE or null"); + + at::Tensor input = input_.contiguous(); + + std::vector shape_for_reduction; + std::vector reduction_dims; + int64_t cur_dim = 0; + + auto in_sizes = input.sizes(); + for (int64_t i = 0; i < ndim; ++i) { + const int64_t blk = block_size[i]; + const int64_t dim = in_sizes[i]; + + if (blk != dim && blk > 1) { + VK_CHECK_COND( + dim % blk == 0, + "Input size ", + dim, + " is not divisible by block_size ", + blk, + " at dimension ", + i); + shape_for_reduction.push_back(dim / blk); + shape_for_reduction.push_back(blk); + reduction_dims.push_back(cur_dim + 1); + cur_dim += 2; + } else { + shape_for_reduction.push_back(dim); + if (blk != 1) { + reduction_dims.push_back(cur_dim); + } + cur_dim += 1; + } + } + + at::Tensor input_reshaped = input.view(shape_for_reduction); + + std::vector shape_after_reduction = shape_for_reduction; + for (int64_t d : reduction_dims) { + shape_after_reduction[d] = 1; + } + + at::Tensor scale_b = scale.view(shape_after_reduction).to(out_dtype); + + at::Tensor zp_b; + if (has_zp) { + zp_b = (*zero_point_opt).view(shape_after_reduction).to(out_dtype); + } + + at::Tensor input_fp = input_reshaped.to(out_dtype); + at::Tensor dq; + + if (zero_point_domain == "INT") { + VK_CHECK_COND(has_zp, "INT zero_point_domain requires zero_point tensor"); + dq = (input_fp - zp_b) * scale_b; + } else if (zero_point_domain == "NONE" || zero_point_domain.empty()) { + VK_CHECK_COND( + !has_zp, "zero_point must be None when domain is NONE / null"); + dq = input_fp * scale_b; + } else { + VK_CHECK_COND( + has_zp && zero_point_domain == "FLOAT", + "zero_point_domain must be INT, FLOAT, NONE or null"); + const float mid_point = (quant_max + quant_min + 1) * 0.5f; + at::Tensor min_val = zp_b - scale_b * mid_point; + dq = input_fp * scale_b + min_val; + } + + dq = dq.view(in_sizes); + + return dq; +} + // Wrapper function to maintain compatibility with existing test code (above is // a good reference for how the python implementation works) at::Tensor quantize_affine_reference_impl( @@ -155,6 +259,26 @@ at::Tensor quantize_affine_reference_impl( std::string("INT")); } +// Wrapper function for dequantize_affine +at::Tensor dequantize_affine_reference_impl( + const at::Tensor& input, + const std::vector& block_size, + const at::Tensor& scale, + const at::Tensor& zero_point, + int64_t quant_min, + int64_t quant_max, + at::ScalarType dtype) { + return dequantize_affine_reference_impl( + input, + block_size, + scale, + c10::optional(zero_point), + quant_min, + quant_max, + dtype, + std::string("INT")); +} + void test_vulkan_quantize_affine_impl( const std::vector& 
input_sizes, const std::vector& block_size, @@ -440,3 +564,296 @@ TEST(VulkanQuantizeAffineTest, test_4d_quantization) { at::kFloat, // input dtype at::kChar); // output dtype } + +void test_vulkan_dequantize_affine_impl( + const std::vector& input_sizes, + const std::vector& block_size, + const std::vector& scales, + const std::vector& zero_points, + int64_t quant_min, + int64_t quant_max, + at::ScalarType in_dtype = at::kChar, + at::ScalarType out_dtype = at::kFloat, + const vkcompute::utils::StorageType in_storage = + vkcompute::utils::kTexture3D, + const vkcompute::utils::StorageType out_storage = + vkcompute::utils::kTexture3D) { + // Create input tensor with random integer values within quant_min and + // quant_max + std::vector input_sizes_int64( + input_sizes.begin(), input_sizes.end()); + at::Tensor input = at::randint( + quant_min, + quant_max + 1, + input_sizes_int64, + at::device(at::kCPU).dtype(in_dtype)); + + // Create scale and zero_point tensors + at::Tensor scale_tensor = + at::tensor(scales, at::device(at::kCPU).dtype(at::kFloat)); + at::Tensor zero_point_tensor = + at::tensor(zero_points, at::device(at::kCPU).dtype(at::kInt)); + + // Get reference output + at::Tensor reference_out = dequantize_affine_reference_impl( + input, + block_size, + scale_tensor, + zero_point_tensor, + quant_min, + quant_max, + out_dtype); + + using namespace vkcompute; + + GraphConfig config; + config.set_storage_type_override(in_storage); + ComputeGraph graph(config); + + IOValueRef r_input = graph.add_input_tensor( + input.sizes().vec(), from_at_scalartype(input.scalar_type()), in_storage); + + // Create block_size as IntList instead of Tensor + std::vector block_size_copy(block_size); + const ValueRef r_block_size = + graph.add_scalar_list(std::move(block_size_copy)); + + IOValueRef r_scale = graph.add_input_tensor( + scale_tensor.sizes().vec(), + vkapi::kFloat, + utils::kBuffer, + utils::kWidthPacked); + IOValueRef r_zero_point = graph.add_input_tensor( + zero_point_tensor.sizes().vec(), + vkapi::kInt, + utils::kBuffer, + utils::kWidthPacked); + + // Create input_dtype scalar + const ValueRef r_input_dtype = + graph.add_scalar(static_cast(in_dtype)); + const ValueRef r_quant_min = graph.add_scalar(quant_min); + const ValueRef r_quant_max = graph.add_scalar(quant_max); + const ValueRef r_output_dtype = + graph.add_scalar(static_cast(out_dtype)); + + const ValueRef r_out = graph.add_tensor( + input.sizes().vec(), from_at_scalartype(out_dtype), out_storage); + + // Match the argument order in dequantize_affine_impl in Dequantize.cpp: + // input, block_size, scale, zero_point, input_dtype, quant_min, quant_max, + // output_dtype, output + VK_GET_OP_FN("torchao.dequantize_affine.default") + (graph, + { + r_input.value, + r_block_size, + r_scale.value, + r_zero_point.value, + r_input_dtype, + r_quant_min, + r_quant_max, + r_output_dtype, + r_out, + }); + + ValueRef staging_out = graph.set_output_tensor(r_out); + + graph.prepare(); + graph.prepack(); + graph.encode_execute(); + + // Copy input data to GPU + graph.copy_into_staging( + r_input.staging, input.const_data_ptr(), input.numel()); + + // Copy scale tensor to GPU + graph.copy_into_staging( + r_scale.staging, scale_tensor.const_data_ptr(), scale_tensor.numel()); + + // Copy zero_point tensor to GPU + graph.copy_into_staging( + r_zero_point.staging, + zero_point_tensor.const_data_ptr(), + zero_point_tensor.numel()); + + // Execute the graph + graph.execute(); + + // Copy output data back to CPU + at::Tensor vk_out = 
at::empty_like(reference_out).contiguous(); + graph.copy_from_staging( + staging_out, vk_out.mutable_data_ptr(), vk_out.numel()); + + // Compare outputs + const bool output_correct = + at::allclose(reference_out, vk_out, /*rtol=*/1e-5, /*atol=*/1e-5); + if (!output_correct) { + std::cout << "\nFailed with parameters:" << std::endl; + std::cout << " input_sizes: ["; + for (size_t i = 0; i < input_sizes.size(); i++) { + std::cout << input_sizes[i] << (i < input_sizes.size() - 1 ? ", " : ""); + } + std::cout << "]" << std::endl; + std::cout << " block_size: ["; + for (size_t i = 0; i < block_size.size(); i++) { + std::cout << block_size[i] << (i < block_size.size() - 1 ? ", " : ""); + } + std::cout << "]" << std::endl; + std::cout << " scales: ["; + for (size_t i = 0; i < scales.size(); i++) { + std::cout << scales[i] << (i < scales.size() - 1 ? ", " : ""); + } + std::cout << "]" << std::endl; + std::cout << " zero_points: ["; + for (size_t i = 0; i < zero_points.size(); i++) { + std::cout << zero_points[i] << (i < zero_points.size() - 1 ? ", " : ""); + } + std::cout << "]" << std::endl; + std::cout << " quant_min: " << quant_min << std::endl; + std::cout << " quant_max: " << quant_max << std::endl; + std::cout << " storage type: " + << (in_storage == vkcompute::utils::kBuffer ? "buffer" + : "texture") + << std::endl; + + std::cout << "input:" << std::endl << input << std::endl; + std::cout << "reference:" << std::endl << reference_out << std::endl; + std::cout << "vulkan:" << std::endl << vk_out << std::endl; + } + + ASSERT_TRUE(output_correct); +} + +// Wrapper function to test both buffer and texture storage types +void test_vulkan_dequantize_affine( + const std::vector& input_sizes, + const std::vector& block_size, + const std::vector& scales, + const std::vector& zero_points, + int64_t quant_min, + int64_t quant_max, + at::ScalarType in_dtype = at::kChar, + at::ScalarType out_dtype = at::kFloat) { + // Test with buffer storage + test_vulkan_dequantize_affine_impl( + input_sizes, + block_size, + scales, + zero_points, + quant_min, + quant_max, + in_dtype, + out_dtype, + vkcompute::utils::kBuffer, + vkcompute::utils::kBuffer); + + // Test with texture storage + test_vulkan_dequantize_affine_impl( + input_sizes, + block_size, + scales, + zero_points, + quant_min, + quant_max, + in_dtype, + out_dtype, + vkcompute::utils::kTexture3D, + vkcompute::utils::kTexture3D); +} + +TEST(VulkanDequantizeAffineTest, test_1d_dequantization) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + // 1D: 1x1x1x12 Tensor, block_size is 3 + test_vulkan_dequantize_affine( + {12}, // input_sizes + {3}, // block_size + {0.1f, 0.2f, 0.15f, 0.25f}, // scales (4 blocks) + {10, -20, 5, 30}, // zero_points (4 blocks) + -128, // quant_min (char min) + 127, // quant_max (char max) + at::kChar, // input dtype + at::kFloat); // output dtype +} + +TEST(VulkanDequantizeAffineTest, test_2d_dequantization) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + // 2D: 1x1x8x6 Tensor, block_size is 1x1x2x3 (8/2=4, 6/3=2, so 4*2=8 blocks) + test_vulkan_dequantize_affine( + {8, 6}, // input_sizes + {2, 3}, // block_size (1/1=1, 1/1=1, 8/2=4, 6/3=2) + {0.1f, 0.2f, 0.15f, 0.25f, 0.3f, 0.05f, 0.4f, 0.35f}, // scales (8 blocks) + {-10, 15, 0, 25, -5, 20, 10, -15}, // zero_points (8 blocks) + -128, // quant_min (char min) + 127, // quant_max (char max) + at::kChar, // input dtype + at::kFloat); // output dtype +} + 
+TEST(VulkanDequantizeAffineTest, test_3d_dequantization) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + // 3D: 1x6x4x6 Tensor, block_size is 3x2x2 (6/3=2, 4/2=2, 6/2=3, so 2*2*3=12 + // blocks) + test_vulkan_dequantize_affine( + {6, 4, 6}, // input_sizes (changed 7->6 so divisible by 3) + {3, + 2, + 2}, // block_size (6 divisible by 3, 4 divisible by 2, 6 divisible by 2) + {0.1f, + 0.2f, + 0.15f, + 0.25f, + 0.3f, + 0.05f, + 0.4f, + 0.35f, + 0.12f, + 0.18f, + 0.22f, + 0.28f}, // scales (12 blocks) + {-15, 10, 5, -25, 20, -10, 15, -5, 8, -12, 18, -8}, // zero_points (12 + // blocks) + -128, // quant_min (char min) + 127, // quant_max (char max) + at::kChar, // input dtype + at::kFloat); // output dtype +} + +TEST(VulkanDequantizeAffineTest, test_4d_dequantization) { + if (!vkcompute::api::context() + ->adapter_ptr() + ->has_full_int8_buffers_support()) { + GTEST_SKIP(); + } + // 4D: 8x6x6x6 Tensor, block_size is 2x3x2x3 (8/2=4, 6/3=2, 6/2=3, 6/3=2, so + // 4*2*3*2=48 blocks) + test_vulkan_dequantize_affine( + {8, 6, 6, 6}, // input_sizes + {2, 3, 2, 3}, // block_size (8/2=4, 6/3=2, 6/2=3, 6/3=2) + {0.1f, 0.2f, 0.15f, 0.25f, 0.3f, 0.05f, 0.4f, 0.35f, 0.12f, 0.18f, + 0.22f, 0.28f, 0.32f, 0.08f, 0.45f, 0.38f, 0.14f, 0.24f, 0.16f, 0.26f, + 0.34f, 0.06f, 0.44f, 0.36f, 0.11f, 0.21f, 0.13f, 0.23f, 0.31f, 0.07f, + 0.41f, 0.37f, 0.19f, 0.29f, 0.17f, 0.27f, 0.33f, 0.09f, 0.43f, 0.39f, + 0.10f, 0.20f, 0.14f, 0.24f, 0.30f, 0.04f, 0.40f, 0.34f}, // scales (48 + // blocks) + {-20, 10, 5, -15, 25, -10, 15, -5, 8, -12, 18, -8, 22, + -18, 12, -22, -25, 15, 0, -20, 30, -5, 20, -10, 5, -25, + 10, -15, 35, -15, 25, -35, -30, 20, -5, -25, 40, 0, 30, + -40, 10, -30, 15, -10, 45, -20, 35, -45}, // zero_points (48 blocks) + -128, // quant_min (char min) + 127, // quant_max (char max) + at::kChar, // input dtype + at::kFloat); // output dtype +}
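
For reference, the block indexing used by both the buffer and texture shaders follows directly from the push constants set up in add_dequantize_block_wise_node: blockSize, numBlocks = tensorSize / blockSize, and blockStride, all in WHCN order. Below is a minimal CPU-side C++ sketch of that math; BlockParams, make_block_params, and block_id_for are hypothetical names introduced here for illustration, not part of the patch.

#include <array>
#include <cassert>
#include <cstdio>

using ivec4 = std::array<int, 4>;

struct BlockParams {
  ivec4 blockSize;   // bW, bH, bC, bN
  ivec4 numBlocks;   // tW/bW, tH/bH, tC/bC, tN/bN
  ivec4 blockStride; // pre-computed linear strides over the block grid
};

BlockParams make_block_params(const ivec4& tensor_size, const ivec4& block_size) {
  BlockParams p{};
  p.blockSize = block_size;
  for (int i = 0; i < 4; ++i) {
    // mirrors the divisibility VK_CHECK_COND in dequantize_affine_impl
    assert(tensor_size[i] % block_size[i] == 0);
    p.numBlocks[i] = tensor_size[i] / block_size[i];
  }
  p.blockStride = {
      1,
      p.numBlocks[0],
      p.numBlocks[0] * p.numBlocks[1],
      p.numBlocks[0] * p.numBlocks[1] * p.numBlocks[2]};
  return p;
}

// Same per-element computation as the GLSL: divide the WHCN tensor index by
// blockSize to get the block coordinate, then flatten it with blockStride.
int block_id_for(const ivec4& tidx, const BlockParams& p) {
  int id = 0;
  for (int i = 0; i < 4; ++i) {
    id += (tidx[i] / p.blockSize[i]) * p.blockStride[i];
  }
  return id;
}

int main() {
  // The [6, 9, 4] / blockSize [3, 3, 2] example from the shader comment,
  // rewritten in WHCN order: sizes {4, 9, 6, 1}, blockSize {2, 3, 3, 1}.
  const BlockParams p = make_block_params({4, 9, 6, 1}, {2, 3, 3, 1});
  // numBlocks = {2, 3, 2, 1} -> 12 blocks of 2 * 3 * 3 = 18 elements each.
  const int id = block_id_for({3, 4, 5, 0}, p);
  std::printf("block_id = %d\n", id); // 1*1 + 1*2 + 1*6 = 9
  return 0;
}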
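
The formula itself, value = (qvalue - zero_point[block_id]) * scale[block_id], can be checked against a plain CPU loop. The sketch below is a hypothetical standalone reference (distinct from the test file's dequantize_affine_reference_impl); it assumes an int8 input stored contiguously in WHCN order with every dimension divisible by its block size.

#include <array>
#include <cstdint>
#include <vector>

std::vector<float> dequantize_block_wise_ref(
    const std::vector<int8_t>& qvals,
    const std::array<int, 4>& size,         // W, H, C, N
    const std::array<int, 4>& block,        // bW, bH, bC, bN
    const std::vector<float>& scale,        // one entry per block
    const std::vector<int32_t>& zero_point) // one entry per block
{
  const std::array<int, 4> nblk = {
      size[0] / block[0], size[1] / block[1], size[2] / block[2], size[3] / block[3]};
  const std::array<int, 4> stride = {
      1, nblk[0], nblk[0] * nblk[1], nblk[0] * nblk[1] * nblk[2]};
  std::vector<float> out(qvals.size());
  std::size_t i = 0;
  for (int n = 0; n < size[3]; ++n) {
    for (int c = 0; c < size[2]; ++c) {
      for (int h = 0; h < size[1]; ++h) {
        for (int w = 0; w < size[0]; ++w, ++i) {
          const int block_id = (w / block[0]) * stride[0] + (h / block[1]) * stride[1] +
              (c / block[2]) * stride[2] + (n / block[3]) * stride[3];
          // value = (qvalue - zero_point) * scale, applied per block
          out[i] = (static_cast<int>(qvals[i]) - zero_point[block_id]) * scale[block_id];
        }
      }
    }
  }
  return out;
}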
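
In the texture path, dequantize_block_wise() reconstructs the channel and batch components of the tensor index from the folded texture Z coordinate using C_total = numBlocks.z * blockSize.z. A small sketch of that unfolding, assuming the usual z = n * C_total + c packing (inferred from the shader's modulo/division; the packing convention is not stated explicitly in the patch):

struct ZUnfold {
  int c; // channel index within a batch
  int n; // batch index
};

// foldedZ % C_total recovers the channel, foldedZ / C_total the batch,
// matching the tidx construction inside the texture variant of
// dequantize_block_wise().
inline ZUnfold unfold_z(int foldedZ, int C_total) {
  return {foldedZ % C_total, foldedZ / C_total};
}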