diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py
index 9a63d178e2d..1f77b30cda3 100644
--- a/backends/vulkan/op_registry.py
+++ b/backends/vulkan/op_registry.py
@@ -221,13 +221,6 @@ def update_features_impl(op: OpKey):
 @update_features(
     [
         operator.getitem,
-        # Quantization related ops will be fused via graph passes
-        exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
-        exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
-        exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor,
-        exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
-        exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
-        exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
         # Symbolic integer ops
         torch.ops.aten.sym_size.int,
         operator.add,
@@ -250,6 +243,35 @@ def register_ephemeral_op(features: OpFeatures):
     return features
 
 
+@update_features(
+    [
+        exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
+        exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+        exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor,
+        exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
+        exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
+        exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
+        exir_ops.edge.quantized_decomposed.quantize_per_token.default,
+        exir_ops.edge.quantized_decomposed.dequantize_per_token.default,
+        exir_ops.edge.quantized_decomposed.choose_qparams.tensor,
+        exir_ops.edge.quantized_decomposed.choose_qparams_per_token_asymmetric.default,
+    ]
+)
+def register_quantization_op(features: OpFeatures):
+    # Quantization requires buffer storage and width packing for scales and
+    # zero_points, but texture impl features must still be advertised for the
+    # partitioner to work properly.
+    features.texture_impl = TextureImplFeatures(
+        uses_axis_map=True,
+        valid_packed_dims={
+            PackedDim.WIDTH,
+        },
+    )
+    features.buffer_impl = True
+    features.resize_fn = True
+    features.optimal_storage = VkStorageType.BUFFER
+    return features
+
+
 @update_features(
     [
         exir_ops.edge.aten.add.Tensor,
diff --git a/backends/vulkan/runtime/VulkanBackend.cpp b/backends/vulkan/runtime/VulkanBackend.cpp
index 7077a9df59c..28e7574537c 100644
--- a/backends/vulkan/runtime/VulkanBackend.cpp
+++ b/backends/vulkan/runtime/VulkanBackend.cpp
@@ -83,10 +83,14 @@ vkapi::ScalarType get_scalar_type(const vkgraph::VkDataType& vk_datatype) {
       return vkapi::kChar;
     case vkgraph::VkDataType::INT32:
       return vkapi::kInt;
+    case vkgraph::VkDataType::INT64:
+      return vkapi::kLong;
     case vkgraph::VkDataType::FLOAT16:
       return vkapi::kHalf;
     case vkgraph::VkDataType::FLOAT32:
       return vkapi::kFloat;
+    case vkgraph::VkDataType::FLOAT64:
+      return vkapi::kDouble;
   }
 }
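Editor's note: for readers unfamiliar with `op_registry.py`, the `@update_features` decorator maps each op key to a function that configures a default `OpFeatures` record. A minimal sketch of the pattern (the class and registry here are simplified stand-ins, not the real definitions):

```python
from typing import Callable, Dict, List


class OpFeatures:
    """Simplified stand-in for the real OpFeatures record."""

    def __init__(self) -> None:
        self.buffer_impl = False
        self.resize_fn = False


# Maps an op key to the function that configures its features.
_FEATURES_REGISTRY: Dict[object, Callable[[OpFeatures], OpFeatures]] = {}


def update_features(ops: List[object]):
    def decorator(fn: Callable[[OpFeatures], OpFeatures]):
        for op in ops:
            _FEATURES_REGISTRY[op] = fn
        return fn

    return decorator


@update_features(["quantized_decomposed.quantize_per_tensor.default"])
def register_quantization_op(features: OpFeatures) -> OpFeatures:
    features.buffer_impl = True
    features.resize_fn = True
    return features


# The partitioner can then look up features per op key.
features = _FEATURES_REGISTRY["quantized_decomposed.quantize_per_tensor.default"](OpFeatures())
assert features.buffer_impl
```

Moving the quantization ops out of the ephemeral-op list into their own registration lets the partitioner see their real storage requirements instead of treating them as fuse-away ops.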
diff --git a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams.glslh b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams.glslh
index 66620e9b174..d6d27d2e3a3 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams.glslh
+++ b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams.glslh
@@ -9,15 +9,13 @@
 #ifndef CHOOSE_QPARAMS_GLSLH
 #define CHOOSE_QPARAMS_GLSLH
 
-// equivalent of the eps defined in the cpu implementation
-#define SMALL_SCALE_THRESHOLD 6.1e-5
-
 // Calculate scale and zero point from min and max values
 void calculate_scale_and_zero_point(
     float min_val,
     float max_val,
     int qmin,
     int qmax,
+    float eps_threshold,
     out float scale_val,
     out int zero_point_val) {
   // ensure we have zero included in our range
@@ -31,18 +29,18 @@ void calculate_scale_and_zero_point(
     scale_val = 0.1;
   }
 
-  // Cut off small scale
-  if (scale_val < SMALL_SCALE_THRESHOLD) {
+  // Cut off small scale using the provided eps threshold
+  if (scale_val < eps_threshold) {
     float org_scale = scale_val;
-    scale_val = SMALL_SCALE_THRESHOLD;
+    scale_val = eps_threshold;
 
     // Adjust min and max based on new scale
     if (min_val == 0.0) {
-      max_val = SMALL_SCALE_THRESHOLD * float(qmax - qmin);
+      max_val = eps_threshold * float(qmax - qmin);
     } else if (max_val == 0.0) {
-      min_val = -SMALL_SCALE_THRESHOLD * float(qmax - qmin);
+      min_val = -eps_threshold * float(qmax - qmin);
     } else {
-      float amplifier = SMALL_SCALE_THRESHOLD / org_scale;
+      float amplifier = eps_threshold / org_scale;
       min_val *= amplifier;
       max_val *= amplifier;
     }
diff --git a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.glsl
index dcbfe493f34..48681a46c30 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.glsl
@@ -29,6 +29,7 @@ $if MODE == "per_tensor":
   layout(push_constant) uniform restrict Block {
     int quant_min;
     int quant_max;
+    float eps;
   };
 $else:
   layout(push_constant) uniform restrict Block {
@@ -175,7 +176,7 @@ void choose_qparams_per_tensor() {
     float scale_val;
     int zero_point_val;
 
-    calculate_scale_and_zero_point(global_min, global_max, quant_min, quant_max, scale_val, zero_point_val);
+    calculate_scale_and_zero_point(global_min, global_max, quant_min, quant_max, eps, scale_val, zero_point_val);
 
     t_scale[0] = scale_val;
     t_zero_point[0] = zero_point_val;
@@ -260,7 +261,7 @@ void choose_qparams_per_token() {
       float scale_val;
       int zero_point_val;
 
-      calculate_scale_and_zero_point(token_min, token_max, quant_min, quant_max, scale_val, zero_point_val);
+      calculate_scale_and_zero_point(token_min, token_max, quant_min, quant_max, 1e-5, scale_val, zero_point_val);
 
       t_scale[token_id] = scale_val;
       t_zero_point[token_id] = zero_point_val;
diff --git a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.glsl
index 282f1de170a..5076b2d68e9 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.glsl
@@ -30,6 +30,7 @@ $if MODE == "per_tensor":
   layout(push_constant) uniform restrict Block {
     int quant_min;
     int quant_max;
+    float eps;
   };
 $else:
   layout(push_constant) uniform restrict Block {
@@ -234,7 +235,7 @@ void choose_qparams_per_tensor() {
     float scale_val;
     int zero_point_val;
 
-    calculate_scale_and_zero_point(global_min, global_max, quant_min, quant_max, scale_val, zero_point_val);
+    calculate_scale_and_zero_point(global_min, global_max, quant_min, quant_max, eps, scale_val, zero_point_val);
 
     write_texel(t_scale, ivec3(0, 0, 0), vec4(scale_val, 0.0, 0.0, 0.0));
     write_texel(t_zero_point, ivec3(0, 0, 0), ivec4(zero_point_val, 0, 0, 0));
@@ -372,7 +373,7 @@ void choose_qparams_per_token() {
       float scale_val;
       int zero_point_val;
 
-      calculate_scale_and_zero_point(token_min, token_max, quant_min, quant_max, scale_val, zero_point_val);
+      calculate_scale_and_zero_point(token_min, token_max, quant_min, quant_max, 1e-5, scale_val, zero_point_val);
 
       // Convert token_id to 3D coordinates for output texture
       // Assuming output tensors have the same layout as input but with different dimensions
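Editor's note: the shader logic is easier to follow outside GLSL. Below is a Python sketch of `calculate_scale_and_zero_point` with the new `eps_threshold` parameter. The small-scale cutoff mirrors the hunk above; the degenerate-scale fallback condition and the final zero-point derivation are not shown in the hunk, so those two pieces are assumptions (the standard asymmetric formula):

```python
def calculate_scale_and_zero_point(min_val, max_val, qmin, qmax, eps_threshold):
    # Ensure zero is included in the representable range.
    min_val = min(min_val, 0.0)
    max_val = max(max_val, 0.0)

    scale = (max_val - min_val) / float(qmax - qmin)
    if scale == 0.0:  # degenerate range; the shader falls back to 0.1
        scale = 0.1

    # Cut off small scale using the provided eps threshold.
    if scale < eps_threshold:
        org_scale = scale
        scale = eps_threshold
        # Adjust min and max based on the new scale.
        if min_val == 0.0:
            max_val = eps_threshold * (qmax - qmin)
        elif max_val == 0.0:
            min_val = -eps_threshold * (qmax - qmin)
        else:
            amplifier = eps_threshold / org_scale
            min_val *= amplifier
            max_val *= amplifier

    # Standard asymmetric zero-point derivation (assumed; not in the hunk).
    zero_point = int(round(qmin - min_val / scale))
    zero_point = max(qmin, min(qmax, zero_point))
    return scale, zero_point
```

Note the asymmetry in the patch: the per-tensor path now receives `eps` through the push constants, while the per-token path hard-codes `1e-5`.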
diff --git a/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp b/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp
index 1dc2d34afbf..5e9599b91e6 100644
--- a/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp
@@ -150,6 +150,7 @@ void add_choose_qparams_tensor_node(
     const ValueRef& input,
     const ValueRef& quant_min,
     const ValueRef& quant_max,
+    const ValueRef& eps,
     const ValueRef& scale_out,
     const ValueRef& zero_point_out) {
   std::string kernel_name("choose_qparams_tensor");
@@ -158,6 +159,7 @@ void add_choose_qparams_tensor_node(
 
   int quant_min_val = static_cast<int>(graph.get_int(quant_min));
   int quant_max_val = static_cast<int>(graph.get_int(quant_max));
+  float eps_val = static_cast<float>(graph.get_double(eps));
 
   vkapi::ParamsBindList param_ubos;
 
@@ -180,6 +182,7 @@ void add_choose_qparams_tensor_node(
     push_constants = {
         PushConstantDataInfo(&quant_min_val, sizeof(int)),
         PushConstantDataInfo(&quant_max_val, sizeof(int)),
+        PushConstantDataInfo(&eps_val, sizeof(float)),
     };
 
   graph.execute_nodes().emplace_back(new DynamicDispatchNode(
@@ -275,8 +278,22 @@ void choose_qparams_tensor_impl(
   const ValueRef input = args[arg_idx++];
   const ValueRef quant_min = args[arg_idx++];
   const ValueRef quant_max = args[arg_idx++];
-  const ValueRef scale_out = args[arg_idx++];
-  const ValueRef zero_point_out = args[arg_idx++];
+  const ValueRef eps = args[arg_idx++]; // eps parameter; forwarded to the shader
+  const ValueRef dtype =
+      args[arg_idx++]; // dtype parameter (will be voided)
+  const ValueRef out_tuple_ref = args[arg_idx++];
+
+  ValueRef scale_out = kDummyValueRef;
+  ValueRef zero_point_out = kDummyValueRef;
+
+  {
+    const ValueListPtr out_tuple = graph.get_value_list(out_tuple_ref);
+    scale_out = out_tuple->at(0);
+    zero_point_out = out_tuple->at(1);
+  }
+
+  // Void the unused dtype parameter to match ATen signature
+  (void)dtype;
 
   // Check tensor types
   VK_CHECK_COND(graph.val_is_tensor(input));
@@ -289,13 +306,10 @@ void choose_qparams_tensor_impl(
       graph.dtype_of(input) == vkapi::kHalf ||
       graph.dtype_of(input) == vkapi::kDouble);
 
-  // Verify output types - accept CPU types but convert to GPU types
-  VK_CHECK_COND(
-      graph.dtype_of(scale_out) == vkapi::kFloat ||
-      graph.dtype_of(scale_out) == vkapi::kDouble);
-  VK_CHECK_COND(
-      graph.dtype_of(zero_point_out) == vkapi::kInt ||
-      graph.dtype_of(zero_point_out) == vkapi::kLong);
+  // Verify output types - only accept Vulkan-supported types.
+  // The Vulkan backend only supports float32 and int32, not float64/int64.
+  VK_CHECK_COND(graph.dtype_of(scale_out) == vkapi::kFloat);
+  VK_CHECK_COND(graph.dtype_of(zero_point_out) == vkapi::kInt);
 
   // Check that texture storage is width packed
   if (!graph.is_buffer_storage(input)) {
@@ -303,7 +317,7 @@ void choose_qparams_tensor_impl(
   }
 
   add_choose_qparams_tensor_node(
-      graph, input, quant_min, quant_max, scale_out, zero_point_out);
+      graph, input, quant_min, quant_max, eps, scale_out, zero_point_out);
 }
 
 void choose_qparams_per_token_asymmetric_impl(
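Editor's note: the new argument order mirrors the `quantized_decomposed` schema, where `choose_qparams.tensor` takes `(input, quant_min, quant_max, eps, dtype)` and returns a `(scale, zero_point)` tuple. A sketch of exercising it from Python, assuming the op library registered by `torch.ao.quantization.fx._decomposed`:

```python
import torch

# Importing the module registers the quantized_decomposed ops.
from torch.ao.quantization.fx import _decomposed  # noqa: F401

x = torch.randn(4, 8)
scale, zero_point = torch.ops.quantized_decomposed.choose_qparams.tensor(
    x, -128, 127, 6.1e-5, torch.int8
)
# On CPU, scale comes back as float64 and zero_point as int64, which is
# why the old Vulkan checks accepted kDouble/kLong before the downcast
# support made strict kFloat/kInt outputs viable.
print(scale, zero_point)
```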
@@ -311,8 +325,21 @@ void choose_qparams_per_token_asymmetric_impl(
     ComputeGraph& graph,
     const std::vector<ValueRef>& args) {
   int arg_idx = 0;
   const ValueRef input = args[arg_idx++];
-  const ValueRef scale_out = args[arg_idx++];
-  const ValueRef zero_point_out = args[arg_idx++];
+  const ValueRef dtype =
+      args[arg_idx++]; // dtype parameter (will be voided)
+  const ValueRef out_tuple_ref = args[arg_idx++];
+
+  ValueRef scale_out = kDummyValueRef;
+  ValueRef zero_point_out = kDummyValueRef;
+
+  {
+    const ValueListPtr out_tuple = graph.get_value_list(out_tuple_ref);
+    scale_out = out_tuple->at(0);
+    zero_point_out = out_tuple->at(1);
+  }
+
+  // Void the unused parameter to match ATen signature
+  (void)dtype;
 
   // Check tensor types
   VK_CHECK_COND(graph.val_is_tensor(input));
@@ -325,22 +352,20 @@ void choose_qparams_per_token_asymmetric_impl(
       graph.dtype_of(input) == vkapi::kHalf ||
       graph.dtype_of(input) == vkapi::kDouble);
 
-  // Verify output types - accept CPU types but convert to GPU types
-  VK_CHECK_COND(
-      graph.dtype_of(scale_out) == vkapi::kFloat ||
-      graph.dtype_of(scale_out) == vkapi::kDouble);
-  VK_CHECK_COND(
-      graph.dtype_of(zero_point_out) == vkapi::kInt ||
-      graph.dtype_of(zero_point_out) == vkapi::kLong);
+  // Verify output types - only accept Vulkan-supported types.
+  // The Vulkan backend only supports float32 and int32, not float64/int64.
+  VK_CHECK_COND(graph.dtype_of(scale_out) == vkapi::kFloat);
+  VK_CHECK_COND(graph.dtype_of(zero_point_out) == vkapi::kInt);
 
   add_choose_qparams_per_token_asymmetric_node(
       graph, input, scale_out, zero_point_out);
 }
 
 REGISTER_OPERATORS {
-  VK_REGISTER_OP(choose_qparams.tensor, choose_qparams_tensor_impl);
   VK_REGISTER_OP(
-      choose_qparams_per_token_asymmetric.default,
+      quantized_decomposed.choose_qparams.tensor, choose_qparams_tensor_impl);
+  VK_REGISTER_OP(
+      quantized_decomposed.choose_qparams_per_token_asymmetric.default,
       choose_qparams_per_token_asymmetric_impl);
 }
diff --git a/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp b/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp
index 77a51ce24f9..3838da9a151 100644
--- a/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp
@@ -180,8 +180,15 @@ void dequantize_per_tensor_impl(
   const ValueRef zero_point = args[arg_idx++];
   const ValueRef quant_min = args[arg_idx++];
   const ValueRef quant_max = args[arg_idx++];
+  const ValueRef dtype = args[arg_idx++]; // Added dtype parameter
+  const ValueRef output_dtype = args[arg_idx++]; // Added output_dtype parameter
   const ValueRef output = args[arg_idx++];
 
+  // Suppress unused variable warnings - dtype and output_dtype are inferred
+  // from output
+  (void)dtype;
+  (void)output_dtype;
+
   // Check tensor types
   VK_CHECK_COND(graph.val_is_tensor(input));
   VK_CHECK_COND(graph.val_is_tensor(output));
@@ -212,8 +219,15 @@ void dequantize_per_token_impl(
   const ValueRef zero_point = args[arg_idx++];
   const ValueRef quant_min = args[arg_idx++];
   const ValueRef quant_max = args[arg_idx++];
+  const ValueRef dtype = args[arg_idx++]; // Added dtype parameter
+  const ValueRef output_dtype = args[arg_idx++]; // Added output_dtype parameter
   const ValueRef output = args[arg_idx++];
 
+  // Suppress unused variable warnings - dtype and output_dtype are inferred
+  // from output
+  (void)dtype;
+  (void)output_dtype;
+
   // Check tensor types
   VK_CHECK_COND(graph.val_is_tensor(input));
   VK_CHECK_COND(graph.val_is_tensor(scale));
@@ -257,18 +271,34 @@ void dequantize_per_token_impl(
   const auto scale_sizes = graph.sizes_of(scale);
   const auto zero_point_sizes = graph.sizes_of(zero_point);
 
-  VK_CHECK_COND(scale_sizes.size() == 1);
-  VK_CHECK_COND(zero_point_sizes.size() == 1);
-  VK_CHECK_COND(scale_sizes[0] == num_tokens);
-  VK_CHECK_COND(zero_point_sizes[0] == num_tokens);
+  // Calculate the total number of elements in the scale and zero_point tensors
+  int64_t scale_numel = 1;
+  for (size_t i = 0; i < scale_sizes.size(); i++) {
+    scale_numel *= scale_sizes[i];
+  }
+
+  int64_t zero_point_numel = 1;
+  for (size_t i = 0; i < zero_point_sizes.size(); i++) {
+    zero_point_numel *= zero_point_sizes[i];
+  }
+
+  // Check that the total number of elements matches num_tokens.
+  // This allows for both 1D tensors (size [num_tokens]) and reshaped tensors
+  // (size [num_tokens, 1]).
+  VK_CHECK_COND(scale_numel == num_tokens);
+  VK_CHECK_COND(zero_point_numel == num_tokens);
 
   add_dequantize_per_token_node(
       graph, input, scale, zero_point, quant_min, quant_max, output);
 }
 
 REGISTER_OPERATORS {
-  VK_REGISTER_OP(dequantize_per_tensor.default, dequantize_per_tensor_impl);
-  VK_REGISTER_OP(dequantize_per_token.default, dequantize_per_token_impl);
+  VK_REGISTER_OP(
+      quantized_decomposed.dequantize_per_tensor.default,
+      dequantize_per_tensor_impl);
+  VK_REGISTER_OP(
+      quantized_decomposed.dequantize_per_token.default,
+      dequantize_per_token_impl);
 }
 
 } // namespace vkcompute
diff --git a/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp b/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp
index 49277b4d718..f8f930bf0fb 100644
--- a/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp
@@ -180,8 +180,12 @@ void quantize_per_tensor_impl(
   const ValueRef zero_point = args[arg_idx++];
   const ValueRef quant_min = args[arg_idx++];
   const ValueRef quant_max = args[arg_idx++];
+  const ValueRef dtype = args[arg_idx++]; // Added dtype parameter
   const ValueRef output = args[arg_idx++];
 
+  // Suppress unused variable warning - dtype is inferred from output
+  (void)dtype;
+
   // Check tensor types
   VK_CHECK_COND(graph.val_is_tensor(input));
   VK_CHECK_COND(graph.val_is_tensor(output));
@@ -205,8 +209,12 @@ void quantize_per_token_impl(
   const ValueRef zero_point = args[arg_idx++];
   const ValueRef quant_min = args[arg_idx++];
   const ValueRef quant_max = args[arg_idx++];
+  const ValueRef dtype = args[arg_idx++]; // Added dtype parameter
   const ValueRef output = args[arg_idx++];
 
+  // Suppress unused variable warning - dtype is inferred from output
+  (void)dtype;
+
   // Check tensor types
   VK_CHECK_COND(graph.val_is_tensor(input));
   VK_CHECK_COND(graph.val_is_tensor(scale));
@@ -243,18 +251,33 @@ void quantize_per_token_impl(
   const auto scale_sizes = graph.sizes_of(scale);
   const auto zero_point_sizes = graph.sizes_of(zero_point);
 
-  VK_CHECK_COND(scale_sizes.size() == 1);
-  VK_CHECK_COND(zero_point_sizes.size() == 1);
-  VK_CHECK_COND(scale_sizes[0] == num_tokens);
-  VK_CHECK_COND(zero_point_sizes[0] == num_tokens);
+  // Calculate the total number of elements in the scale and zero_point tensors
+  int64_t scale_numel = 1;
+  for (size_t i = 0; i < scale_sizes.size(); i++) {
+    scale_numel *= scale_sizes[i];
+  }
+
+  int64_t zero_point_numel = 1;
+  for (size_t i = 0; i < zero_point_sizes.size(); i++) {
+    zero_point_numel *= zero_point_sizes[i];
+  }
+
+  // Check that the total number of elements matches num_tokens.
+  // This allows for both 1D tensors (size [num_tokens]) and reshaped tensors
+  // (size [num_tokens, 1]).
+  VK_CHECK_COND(scale_numel == num_tokens);
+  VK_CHECK_COND(zero_point_numel == num_tokens);
 
   add_quantize_per_token_node(
       graph, input, scale, zero_point, quant_min, quant_max, output);
 }
 
 REGISTER_OPERATORS {
-  VK_REGISTER_OP(quantize_per_tensor.default, quantize_per_tensor_impl);
-  VK_REGISTER_OP(quantize_per_token.default, quantize_per_token_impl);
+  VK_REGISTER_OP(
+      quantized_decomposed.quantize_per_tensor.default,
+      quantize_per_tensor_impl);
+  VK_REGISTER_OP(
+      quantized_decomposed.quantize_per_token.default, quantize_per_token_impl);
 }
 
 } // namespace vkcompute
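Editor's note: the switch from rank/shape checks to element-count checks tolerates qparam tensors that upstream passes have reshaped. A quick illustration of what the relaxed check accepts:

```python
import torch

num_tokens = 6
scale_1d = torch.rand(num_tokens)            # shape [6], previously required
scale_2d = scale_1d.reshape(num_tokens, 1)   # shape [6, 1], now also accepted

# Old check: rank must be 1 and size(0) == num_tokens -> rejects scale_2d.
# New check: total element count must equal num_tokens -> accepts both.
assert scale_1d.numel() == num_tokens
assert scale_2d.numel() == num_tokens
```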
diff --git a/backends/vulkan/serialization/schema.fbs b/backends/vulkan/serialization/schema.fbs
index f112581c498..99ba6a86594 100644
--- a/backends/vulkan/serialization/schema.fbs
+++ b/backends/vulkan/serialization/schema.fbs
@@ -18,6 +18,8 @@ enum VkDataType : byte {
   INT32 = 3,
   FLOAT16 = 4,
   FLOAT32 = 5,
+  FLOAT64 = 6,
+  INT64 = 7,
 }
 
 // Describes what kind of GPU resource should be used to represent a tensor. The
diff --git a/backends/vulkan/serialization/vulkan_graph_builder.py b/backends/vulkan/serialization/vulkan_graph_builder.py
index 5bae0475c28..cd876bd6305 100644
--- a/backends/vulkan/serialization/vulkan_graph_builder.py
+++ b/backends/vulkan/serialization/vulkan_graph_builder.py
@@ -45,9 +45,11 @@ def __init__(
         self,
         program: ExportedProgram,
         delegate_mapping_builder: DelegateMappingBuilder,
+        downcast_64_bit: bool = True,
     ) -> None:
         self.program = program
         self.delegate_mapping_builder = delegate_mapping_builder
+        self.downcast_64_bit = downcast_64_bit
         self.chain = []
         self.values = []
         self.input_ids = []
@@ -72,13 +74,14 @@ def get_vk_datatype(torch_dtype: torch.dtype) -> vk_graph_schema.VkDataType:
             return vk_graph_schema.VkDataType.INT8
         elif torch_dtype == torch.int32:
             return vk_graph_schema.VkDataType.INT32
+        elif torch_dtype == torch.int64:
+            return vk_graph_schema.VkDataType.INT64
         elif torch_dtype == torch.float16:
             return vk_graph_schema.VkDataType.FLOAT16
         elif torch_dtype == torch.float32:
             return vk_graph_schema.VkDataType.FLOAT32
-        # Narrowing conversion for index tensor produced by max_poolNd_with_indices.
-        elif torch_dtype == torch.int64:
-            return vk_graph_schema.VkDataType.INT32
+        elif torch_dtype == torch.float64:
+            return vk_graph_schema.VkDataType.FLOAT64
         else:
             raise AssertionError(f"Invalid dtype for vulkan_preprocess ({torch_dtype})")
@@ -201,11 +204,20 @@ def create_tensor_value(self, spec: TensorSpec, constant_id: int = -1) -> int:
         # pyre-ignore[16]
         memory_layout = spec.vk_memory_layout
 
+        # Apply downcast logic before getting the VK datatype
+        effective_dtype = spec.dtype
+        if self.downcast_64_bit and spec.dtype == torch.float64:
+            effective_dtype = torch.float32
+        elif self.downcast_64_bit and spec.dtype == torch.int64:
+            effective_dtype = torch.int32
+
+        datatype = self.get_vk_datatype(effective_dtype)
+
         new_id = len(self.values)
         self.values.append(
             vk_graph_schema.VkValue(
                 value=vk_graph_schema.VkTensor(
-                    datatype=self.get_vk_datatype(spec.dtype),
+                    datatype=datatype,
                     dims=spec.shape,
                     constant_id=constant_id,
                     mem_obj_id=mem_obj_id,
diff --git a/backends/vulkan/serialization/vulkan_graph_schema.py b/backends/vulkan/serialization/vulkan_graph_schema.py
index 35113bc623a..f845e5601a7 100644
--- a/backends/vulkan/serialization/vulkan_graph_schema.py
+++ b/backends/vulkan/serialization/vulkan_graph_schema.py
@@ -29,6 +29,8 @@ class VkDataType(IntEnum):
     INT32 = 3
     FLOAT16 = 4
     FLOAT32 = 5
+    FLOAT64 = 6
+    INT64 = 7
 
 
 class VkStorageType(IntEnum):
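Editor's note: the builder change replaces the old unconditional int64-to-int32 narrowing with an opt-out downcast that also covers float64. The behavior, extracted into a standalone helper for clarity (a sketch; the real logic lives inline in `create_tensor_value`):

```python
import torch


def effective_dtype(dtype: torch.dtype, downcast_64_bit: bool = True) -> torch.dtype:
    """Mirror VkGraphBuilder: narrow 64-bit dtypes to their 32-bit
    counterparts when downcast_64_bit is set; otherwise serialize the 64-bit
    dtype directly using the new FLOAT64/INT64 enum values."""
    if downcast_64_bit and dtype == torch.float64:
        return torch.float32
    if downcast_64_bit and dtype == torch.int64:
        return torch.int32
    return dtype


assert effective_dtype(torch.int64) == torch.int32
assert effective_dtype(torch.int64, downcast_64_bit=False) == torch.int64
```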
diff --git a/backends/vulkan/test/op_tests/choose_qparams_test.cpp b/backends/vulkan/test/op_tests/choose_qparams_test.cpp
index 55e96151387..75b7cbc8960 100644
--- a/backends/vulkan/test/op_tests/choose_qparams_test.cpp
+++ b/backends/vulkan/test/op_tests/choose_qparams_test.cpp
@@ -433,14 +433,23 @@ void test_vulkan_choose_qparams_tensor_impl(
   const ValueRef r_scale = graph.add_tensor({}, vkapi::kFloat, out_storage);
   const ValueRef r_zero_point = graph.add_tensor({}, vkapi::kInt, out_storage);
 
-  VK_GET_OP_FN("choose_qparams.tensor")
+  // Create output tuple
+  const ValueRef r_out_tuple = graph.add_value_list({r_scale, r_zero_point});
+
+  // Add eps and dtype parameters to match ATen signature
+  const ValueRef r_eps = graph.add_scalar(6.1e-5);
+  const ValueRef r_dtype =
+      graph.add_scalar(static_cast<int64_t>(dtype));
+
+  VK_GET_OP_FN("quantized_decomposed.choose_qparams.tensor")
   (graph,
    {
        r_input.value,
        r_quant_min,
        r_quant_max,
-       r_scale,
-       r_zero_point,
+       r_eps,
+       r_dtype,
+       r_out_tuple,
    });
 
   ValueRef staging_scale = graph.set_output_tensor(r_scale);
@@ -647,12 +656,20 @@ void test_vulkan_choose_qparams_per_token_asymmetric_impl(
   const ValueRef r_zero_point =
       graph.add_tensor(output_sizes, vkapi::kInt, out_storage);
 
-  VK_GET_OP_FN("choose_qparams_per_token_asymmetric.default")
+  // Create output tuple
+  const ValueRef r_out_tuple = graph.add_value_list({r_scale, r_zero_point});
+
+  // Add dtype parameter to match ATen signature
+  const ValueRef r_dtype =
+      graph.add_scalar(static_cast<int64_t>(dtype));
+
+  VK_GET_OP_FN(
+      "quantized_decomposed.choose_qparams_per_token_asymmetric.default")
   (graph,
    {
        r_input.value,
-       r_scale,
-       r_zero_point,
+       r_dtype,
+       r_out_tuple,
    });
 
   ValueRef staging_scale = graph.set_output_tensor(r_scale);
diff --git a/backends/vulkan/test/op_tests/dequantize_test.cpp b/backends/vulkan/test/op_tests/dequantize_test.cpp
index 6c604076c41..82f316abe82 100644
--- a/backends/vulkan/test/op_tests/dequantize_test.cpp
+++ b/backends/vulkan/test/op_tests/dequantize_test.cpp
@@ -585,7 +585,10 @@ void test_vulkan_dequantize_per_tensor_impl(
   const ValueRef r_out = graph.add_tensor(
       input.sizes().vec(), from_at_scalartype(out_dtype), out_storage);
 
-  VK_GET_OP_FN("dequantize_per_tensor.default")
+  const ValueRef r_dtype =
+      graph.add_scalar(static_cast<int64_t>(out_dtype));
+
+  VK_GET_OP_FN("quantized_decomposed.dequantize_per_tensor.default")
   (graph,
    {
        r_input.value,
@@ -593,6 +596,8 @@ void test_vulkan_dequantize_per_tensor_impl(
        r_zero_point,
        r_quant_min,
        r_quant_max,
+       r_dtype,
+       r_dtype,
        r_out,
    });
 
@@ -1046,7 +1051,10 @@ void test_vulkan_dequantize_per_token_impl(
   const ValueRef r_out = graph.add_tensor(
       input.sizes().vec(), from_at_scalartype(out_dtype), out_storage);
 
-  VK_GET_OP_FN("dequantize_per_token.default")
+  const ValueRef r_dtype =
+      graph.add_scalar(static_cast<int64_t>(out_dtype));
+
+  VK_GET_OP_FN("quantized_decomposed.dequantize_per_token.default")
   (graph,
    {
        r_input.value,
@@ -1054,6 +1062,8 @@ void test_vulkan_dequantize_per_token_impl(
        r_zero_point.value,
        r_quant_min,
        r_quant_max,
+       r_dtype,
+       r_dtype,
        r_out,
    });
diff --git a/backends/vulkan/test/op_tests/quantize_test.cpp b/backends/vulkan/test/op_tests/quantize_test.cpp
index 150bda6989e..8c5246f6c0c 100644
--- a/backends/vulkan/test/op_tests/quantize_test.cpp
+++ b/backends/vulkan/test/op_tests/quantize_test.cpp
@@ -48,6 +48,16 @@ Tensor& quantize_per_token_out(
     ScalarType dtype,
     Tensor& out);
 
+Tensor& quantize_per_channel_out(
+    const Tensor& input,
+    const Tensor& scale,
+    const Tensor& zero_point,
+    int64_t axis,
+    int64_t quant_min,
+    int64_t quant_max,
+    ScalarType dtype,
+    Tensor& out);
+
 // Wrapper function for quantize_per_tensor_out without context
 Tensor& quantize_per_tensor_out_no_context(
     const Tensor& input,
@@ -74,6 +84,20 @@ Tensor& quantize_per_token_out_no_context(
       input, scale, zero_point, quant_min, quant_max, dtype, out);
 }
 
+// Wrapper function for quantize_per_channel_out without context
+Tensor& quantize_per_channel_out_no_context(
+    const Tensor& input,
+    const Tensor& scale,
+    const Tensor& zero_point,
+    int64_t axis,
+    int64_t quant_min,
+    int64_t quant_max,
+    ScalarType dtype,
+    Tensor& out) {
+  return torch::executor::native::quantize_per_channel_out(
+      input, scale, zero_point, axis, quant_min, quant_max, dtype, out);
+}
+
 // ATen wrapper for quantize_per_tensor
 at::Tensor quantize_per_tensor_aten(
     const at::Tensor& input,
@@ -106,6 +130,23 @@ at::Tensor quantize_per_token_aten(
   return out;
 }
 
+// ATen wrapper for quantize_per_channel
+at::Tensor quantize_per_channel_aten(
+    const at::Tensor& input,
+    const at::Tensor& scale,
+    const at::Tensor& zero_point,
+    int64_t axis,
+    int64_t quant_min,
+    int64_t quant_max,
+    at::ScalarType dtype) {
+  auto out = at::empty_like(input, dtype);
+  ScalarType et_dtype = at_scalartype_to_et_scalartype(dtype);
+
+  WRAP_TO_ATEN(quantize_per_channel_out_no_context, 7)
+  (input, scale, zero_point, axis, quant_min, quant_max, et_dtype, out);
+  return out;
+}
+
 } // namespace native
 } // namespace executor
 } // namespace torch
@@ -160,6 +201,40 @@ void check_quantize_args(
       quant_max);
 }
 
+/**
+ * Helper function to validate quantize_per_channel arguments.
+ * Similar to the validation in op_quantize.cpp.
+ */
+void check_quantize_per_channel_args(
+    const std::vector<int>& input_sizes,
+    const std::vector<float>& scales,
+    const std::vector<int>& zero_points,
+    int64_t axis) {
+  // Normalize axis
+  int64_t normalized_axis = axis;
+  if (normalized_axis < 0) {
+    normalized_axis += input_sizes.size();
+  }
+
+  ASSERT_GE(normalized_axis, 0)
+      << "axis " << axis << " is not legal, normalized axis " << normalized_axis
+      << " should be >= 0";
+
+  ASSERT_LT(normalized_axis, static_cast<int64_t>(input_sizes.size()))
+      << "axis " << axis << " is not legal, normalized axis " << normalized_axis
+      << " should be < input.dim() " << input_sizes.size();
+
+  int64_t num_channels = input_sizes[normalized_axis];
+
+  ASSERT_EQ(num_channels, static_cast<int64_t>(scales.size()))
+      << "Expected scales.size() to match input.size(axis) (" << num_channels
+      << "), but got " << scales.size();
+
+  ASSERT_EQ(num_channels, static_cast<int64_t>(zero_points.size()))
+      << "Expected zero_points.size() to match input.size(axis) ("
+      << num_channels << "), but got " << zero_points.size();
+}
+
 //
 // Reference Implementation
 //
@@ -271,6 +346,110 @@ at::Tensor quantize_per_token_reference_impl(
   return out;
 }
 
+/*
+ * Reference implementation of quantize_per_channel
+ */
+at::Tensor quantize_per_channel_reference_impl(
+    const at::Tensor& input,
+    const at::Tensor& scale,
+    const at::Tensor& zero_point,
+    int64_t axis,
+    int64_t quant_min,
+    int64_t quant_max,
+    at::ScalarType dtype) {
+  // Normalize axis to handle negative values
+  int64_t normalized_axis = axis;
+  if (normalized_axis < 0) {
+    normalized_axis += input.dim();
+  }
+
+  // Create output tensor with the same shape as input but with target dtype
+  at::Tensor output = at::empty_like(input, dtype);
+
+  // Get the number of channels along the quantization axis
+  int64_t num_channels = input.size(normalized_axis);
+
+  // Calculate strides for efficient indexing
+  std::vector<int64_t> input_strides;
+  std::vector<int64_t> input_sizes;
+  for (int64_t i = 0; i < input.dim(); i++) {
+    input_sizes.push_back(input.size(i));
+    input_strides.push_back(input.stride(i));
+  }
+
+  // Get data pointers
+  const float* input_data = input.const_data_ptr<float>();
+  const double* scale_data = scale.const_data_ptr<double>();
+  const int64_t* zero_point_data = zero_point.const_data_ptr<int64_t>();
+
+  // Iterate through all elements in the tensor
+  int64_t total_elements = input.numel();
+
+  // Helper lambda to convert a flat index to multi-dimensional coordinates
+  auto flat_to_coords = [&](int64_t flat_idx, std::vector<int64_t>& coords) {
+    int64_t remaining = flat_idx;
+    for (int64_t dim = input.dim() - 1; dim >= 0; dim--) {
+      coords[dim] = remaining % input_sizes[dim];
+      remaining /= input_sizes[dim];
+    }
+  };
+
+  // Process each element
+  std::vector<int64_t> coords(input.dim());
+  for (int64_t flat_idx = 0; flat_idx < total_elements; flat_idx++) {
+    // Convert flat index to coordinates
+    flat_to_coords(flat_idx, coords);
+
+    // Get the channel index for this element
+    int64_t channel_idx = coords[normalized_axis];
+
+    // Get the quantization parameters for this channel
+    double channel_scale = scale_data[channel_idx];
+    int64_t channel_zero_point = zero_point_data[channel_idx];
+
+    // Get the input value
+    float input_value = input_data[flat_idx];
+
+    // Apply quantization formula: round(input / scale) + zero_point
+    float inv_scale = 1.0f / static_cast<float>(channel_scale);
+    int64_t quantized_value = static_cast<int64_t>(
+        static_cast<float>(channel_zero_point) +
+        std::nearbyint(static_cast<float>(inv_scale * input_value)));
+
+    // Clamp to quantization bounds
+    quantized_value = std::max(quantized_value, quant_min);
+    quantized_value = std::min(quantized_value, quant_max);
+
+    // Store the result based on output dtype
+    switch (dtype) {
+      case at::kByte: {
+        uint8_t* output_data = output.mutable_data_ptr<uint8_t>();
+        output_data[flat_idx] = static_cast<uint8_t>(quantized_value);
+        break;
+      }
+      case at::kChar: {
+        int8_t* output_data = output.mutable_data_ptr<int8_t>();
+        output_data[flat_idx] = static_cast<int8_t>(quantized_value);
+        break;
+      }
+      case at::kShort: {
+        int16_t* output_data = output.mutable_data_ptr<int16_t>();
+        output_data[flat_idx] = static_cast<int16_t>(quantized_value);
+        break;
+      }
+      case at::kInt: {
+        int32_t* output_data = output.mutable_data_ptr<int32_t>();
+        output_data[flat_idx] = static_cast<int32_t>(quantized_value);
+        break;
+      }
+      default:
+        assert(false && "Unsupported output dtype");
+    }
+  }
+
+  return output;
+}
+
 // Forward declaration of implementation functions
 void test_vulkan_quantize_per_tensor_impl(
     const std::vector<int>& input_sizes,
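Editor's note: the element-wise reference above walks coordinates explicitly; the same computation can be expressed with broadcasting, which is handy for spot-checking expected values. A sketch (not the test's code; note the reference above multiplies by a float32 reciprocal of the scale, so results at exact .5 rounding boundaries or extreme scales can differ slightly from this double-precision version):

```python
import torch


def quantize_per_channel_reference(x, scale, zero_point, axis, qmin, qmax, dtype):
    # Reshape qparams so they broadcast along `axis`: e.g. axis=1 on a 4D
    # input gives scale/zero_point of shape [1, C, 1, 1].
    shape = [1] * x.dim()
    shape[axis] = x.size(axis)
    s = scale.reshape(shape).to(torch.float64)
    zp = zero_point.reshape(shape)

    q = torch.round(x.to(torch.float64) / s) + zp  # round(x / scale) + zp
    return q.clamp(qmin, qmax).to(dtype)
```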
@@ -476,7 +655,10 @@ void test_vulkan_quantize_per_tensor_impl(
   const ValueRef r_out = graph.add_tensor(
       input.sizes().vec(), from_at_scalartype(dtype), out_storage);
 
-  VK_GET_OP_FN("quantize_per_tensor.default")
+  const ValueRef r_dtype =
+      graph.add_scalar(static_cast<int64_t>(dtype));
+
+  VK_GET_OP_FN("quantized_decomposed.quantize_per_tensor.default")
   (graph,
    {
        r_input.value,
@@ -484,6 +666,7 @@ void test_vulkan_quantize_per_tensor_impl(
        r_zero_point,
        r_quant_min,
        r_quant_max,
+       r_dtype,
        r_out,
    });
 
@@ -509,7 +692,10 @@ void test_vulkan_quantize_per_tensor_impl(
   at::Tensor reference_int = reference_out.to(at::kInt);
   at::Tensor vk_int = vk_out.to(at::kInt);
 
-  const bool output_correct = at::allclose(reference_int, vk_int);
+  // Tolerance is 1 to address rounding errors and fp math differences between
+  // CPU/GPU
+  const bool output_correct =
+      at::allclose(reference_int, vk_int, /*rtol=*/1, /*atol=*/1);
 
   if (!output_correct) {
     at::Tensor diffs = at::abs(reference_int - vk_int);
@@ -835,7 +1021,10 @@ void test_vulkan_quantize_per_token_impl(
   const ValueRef r_out = graph.add_tensor(
       input.sizes().vec(), from_at_scalartype(dtype), out_storage);
 
-  VK_GET_OP_FN("quantize_per_token.default")
+  const ValueRef r_dtype =
+      graph.add_scalar(static_cast<int64_t>(dtype));
+
+  VK_GET_OP_FN("quantized_decomposed.quantize_per_token.default")
   (graph,
    {
        r_input.value,
@@ -843,6 +1032,7 @@ void test_vulkan_quantize_per_token_impl(
        r_zero_point.value,
        r_quant_min,
        r_quant_max,
+       r_dtype,
        r_out,
    });
 
@@ -881,7 +1071,10 @@ void test_vulkan_quantize_per_token_impl(
   at::Tensor reference_int = reference_out.to(at::kInt);
   at::Tensor vk_int = vk_out.to(at::kInt);
 
-  const bool output_correct = at::allclose(reference_int, vk_int);
+  // Tolerance is 1 to address rounding errors and fp math differences between
+  // CPU/GPU
+  const bool output_correct =
+      at::allclose(reference_int, vk_int, /*rtol=*/1, /*atol=*/1);
 
   if (!output_correct) {
     at::Tensor diffs = at::abs(reference_int - vk_int);
@@ -916,7 +1109,7 @@ void test_vulkan_quantize_per_token_impl(
 }
 
 TEST(
-    VulkanQuantizePerTensorTest,
+    VulkanQuantizePerTokenTest,
     test_reference_quantize_per_token_float_to_int8) {
   std::vector<float> scales = {0.1, 0, 0.3, 0.1, 0.2, 0.3};
   std::vector<int> zero_points = {1, 2, 3, 0, -1, -2};
@@ -932,7 +1125,7 @@ TEST(
 }
 
 TEST(
-    VulkanQuantizePerTensorTest,
+    VulkanQuantizePerTokenTest,
     test_reference_quantize_per_token_float_to_int32) {
   std::vector<float> scales = {0.1, 0, 0.3, 0.1, 0.2, 0.3};
   std::vector<int> zero_points = {1, 2, 3, 0, -1, -2};
@@ -948,7 +1141,7 @@ TEST(
 }
 
 TEST(
-    VulkanQuantizePerTensorTest,
+    VulkanQuantizePerTokenTest,
     test_reference_quantize_per_token_half_to_int32) {
   std::vector<float> scales = {0.1, 0, 0.3, 0.1, 0.2, 0.3};
   std::vector<int> zero_points = {1, 2, 3, 0, -1, -2};
@@ -964,7 +1157,7 @@ TEST(
 }
 
 TEST(
-    VulkanQuantizePerTensorTest,
+    VulkanQuantizePerTokenTest,
     test_reference_quantize_per_token_half_to_uint8) {
   std::vector<float> scales = {0.1, 0, 0.3, 0.1, 0.2, 0.3};
   std::vector<int> zero_points = {1, 2, 3, 0, -1, -2};
@@ -980,7 +1173,7 @@ TEST(
 }
 
 TEST(
-    VulkanQuantizePerTensorTest,
+    VulkanQuantizePerTokenTest,
     test_vulkan_quantize_per_token_float_to_uint8) {
   if (!vkcompute::api::context()
            ->adapter_ptr()
           ->has_full_int8_buffers_support()) {
@@ -1001,9 +1194,7 @@ TEST(
       at::kByte);
 }
 
-TEST(
-    VulkanQuantizePerTensorTest,
-    test_vulkan_quantize_per_token_float_to_int8) {
+TEST(VulkanQuantizePerTokenTest, test_vulkan_quantize_per_token_float_to_int8) {
   if (!vkcompute::api::context()
            ->adapter_ptr()
            ->has_full_int8_buffers_support()) {
@@ -1024,7 +1215,7 @@ TEST(
 }
 
 TEST(
-    VulkanQuantizePerTensorTest,
+    VulkanQuantizePerTokenTest,
     test_vulkan_quantize_per_token_float_to_int32) {
   std::vector<float> scales = {
       -0.5, -0.3, -0.2, 0, 0.1, 0.8, 0.1, 0.2, 0.3, 0.4};
@@ -1041,7 +1232,7 @@ TEST(
 }
 
 TEST(
-    VulkanQuantizePerTensorTest,
+    VulkanQuantizePerTokenTest,
     test_vulkan_quantize_per_token_float_to_int32_small_scales) {
   std::vector<float> scales = {
       0,
@@ -1062,7 +1253,7 @@ TEST(
 }
 
 TEST(
-    VulkanQuantizePerTensorTest,
+    VulkanQuantizePerTokenTest,
     test_vulkan_quantize_per_token_float_to_uint8_many_tokens) {
   if (!vkcompute::api::context()
            ->adapter_ptr()
            ->has_full_int8_buffers_support()) {
@@ -1087,7 +1278,7 @@ TEST(
       at::kByte);
 }
 
-TEST(VulkanQuantizePerTensorTest, test_vulkan_quantize_per_token_half_to_int8) {
+TEST(VulkanQuantizePerTokenTest, test_vulkan_quantize_per_token_half_to_int8) {
   if (!vkcompute::api::context()
            ->adapter_ptr()
            ->has_full_float16_buffers_support()) {
@@ -1107,7 +1298,7 @@ TEST(VulkanQuantizePerTokenTest, test_vulkan_quantize_per_token_half_to_int8) {
 TEST(
-    VulkanQuantizePerTensorTest,
+    VulkanQuantizePerTokenTest,
     test_vulkan_quantize_per_token_double_to_int8) {
   if (!vkcompute::api::context()
            ->adapter_ptr()
            ->has_full_float16_buffers_support()) {
@@ -1126,3 +1317,189 @@ TEST(
       at::kDouble, // input dtype
       at::kChar); // output dtype
 }
+
+void test_reference_quantize_per_channel(
+    const std::vector<int>& input_sizes,
+    const std::vector<float>& pre_scales,
+    const std::vector<int>& zero_points,
+    int64_t axis,
+    int64_t quant_min,
+    int64_t quant_max,
+    at::ScalarType in_dtype = at::kFloat,
+    at::ScalarType dtype = at::kInt) {
+  check_quantize_args(quant_min, quant_max, dtype);
+  check_quantize_per_channel_args(input_sizes, pre_scales, zero_points, axis);
+
+  std::vector<int64_t> input_sizes_int64(
+      input_sizes.begin(), input_sizes.end());
+  at::Tensor input =
+      at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(in_dtype));
+
+  // Fill with a simple pattern: values from 0 to 1 in steps
+  float step = 1.0f / (input.numel() - 1);
+  auto flat_input = input.flatten();
+  for (int i = 0; i < flat_input.numel(); i++) {
+    flat_input[i] = i * step;
+  }
+
+  // Reshape back to original dimensions
+  input = flat_input.reshape(input_sizes_int64);
+
+  std::vector<float> scales = pre_scales;
+  for (auto& s : scales) {
+    s = s < eps ? eps : s;
+  }
+
+  // Create scale and zero_point tensors
+  at::Tensor scale_tensor =
+      at::tensor(scales, at::device(at::kCPU).dtype(at::kDouble));
+  at::Tensor zero_point_tensor =
+      at::tensor(zero_points, at::device(at::kCPU).dtype(at::kLong));
+
+  // Get reference output
+  at::Tensor my_ref = quantize_per_channel_reference_impl(
+      input,
+      scale_tensor,
+      zero_point_tensor,
+      axis,
+      quant_min,
+      quant_max,
+      dtype);
+
+  // Get implementation output
+  at::Tensor cpu_ref = torch::executor::native::quantize_per_channel_aten(
+      input,
+      scale_tensor,
+      zero_point_tensor,
+      axis,
+      quant_min,
+      quant_max,
+      dtype);
+
+  // Get direct ATen implementation output
+  c10::ScalarType aten_dtype = dtype;
+  if (dtype == at::kChar) {
+    aten_dtype = c10::kQInt8;
+  } else if (dtype == at::kByte) {
+    aten_dtype = c10::kQUInt8;
+  }
+
+  // Normalize axis for ATen (it doesn't handle negative values)
+  int64_t normalized_axis = axis;
+  if (normalized_axis < 0) {
+    normalized_axis += input.dim();
+  }
+
+  at::Tensor aten_ref = at::quantize_per_channel(
+      input, scale_tensor, zero_point_tensor, normalized_axis, aten_dtype);
+
+  // Convert to int for consistent display regardless of underlying type
+  at::Tensor my_ref_int = my_ref.to(at::kInt);
+  at::Tensor cpu_ref_int = cpu_ref.to(at::kInt);
+  // For quantized tensors, we need to use int_repr() to get the underlying
+  // integer values
+  at::Tensor aten_ref_int = aten_ref.int_repr().to(at::kInt);
+
+  const bool output_correct = at::equal(my_ref_int, cpu_ref_int);
+  if (!output_correct) {
+    std::cout << "\n"
+              << "Failed with parameters: " << std::endl;
+    std::cout << "  axis: " << axis << std::endl;
+    std::cout << "  input sizes:";
+    for (size_t i = 0; i < input_sizes.size(); i++) {
+      std::cout << " " << input_sizes[i] << " ";
+    }
+    std::cout << "" << std::endl;
+    std::cout << "  scale(s):";
+    for (size_t i = 0; i < scales.size(); i++) {
+      std::cout << " " << scales[i] << " ";
+    }
+    std::cout << "" << std::endl;
+    std::cout << "  zero_point(s):";
+    for (size_t i = 0; i < zero_points.size(); i++) {
+      std::cout << " " << zero_points[i] << " ";
+    }
+    std::cout << "" << std::endl;
+    std::cout << "  quant_min: " << quant_min << std::endl;
+    std::cout << "  quant_max: " << quant_max << std::endl;
+
+    std::cout << "input:" << std::endl;
+    std::cout << input << std::endl;
+    std::cout << "aten_ref:" << std::endl;
+    std::cout << aten_ref_int << std::endl;
+    std::cout << "cpu_ref:" << std::endl;
+    std::cout << cpu_ref_int << std::endl;
+    std::cout << "my_ref:" << std::endl;
+    std::cout << my_ref_int << std::endl;
+  }
+
+  ASSERT_TRUE(output_correct);
+}
+
+TEST(
+    VulkanQuantizePerChannelTest,
+    test_reference_quantize_per_channel_float_to_int8_3D_axis0) {
+  std::vector<float> scales = {0.1, 0.2, 0.3};
+  std::vector<int> zero_points = {0, 5, -2};
+
+  test_reference_quantize_per_channel(
+      {3, 4, 2}, // input sizes
+      scales,
+      zero_points,
+      0, // axis
+      -128, // quant_min
+      127, // quant_max
+      at::kFloat,
+      at::kChar);
+}
+
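Editor's note: the 3D axis-0 case above can be reproduced directly with ATen's own quantizer, which is what `aten_ref` compares against; `int_repr()` exposes the underlying integer values of the quantized tensor. A sketch using the same input pattern as the harness (24 values evenly spaced in [0, 1]):

```python
import torch

x = torch.linspace(0, 1, steps=24).reshape(3, 4, 2)
q = torch.quantize_per_channel(
    x,
    torch.tensor([0.1, 0.2, 0.3], dtype=torch.float64),
    torch.tensor([0, 5, -2]),
    0,  # axis
    torch.qint8,
)
print(q.int_repr())  # integer storage, comparable to the ET kernel output
```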
+TEST(
+    VulkanQuantizePerChannelTest,
+    test_reference_quantize_per_channel_float_to_int8_3D_axis2) {
+  std::vector<float> scales = {0.1, 0.2};
+  std::vector<int> zero_points = {0, 5};
+
+  test_reference_quantize_per_channel(
+      {3, 4, 2}, // input sizes
+      scales,
+      zero_points,
+      2, // axis
+      -128, // quant_min
+      127, // quant_max
+      at::kFloat,
+      at::kChar);
+}
+
+TEST(
+    VulkanQuantizePerChannelTest,
+    test_reference_quantize_per_channel_float_to_int8_3D_axisn1) {
+  std::vector<float> scales = {0.1, 0.2};
+  std::vector<int> zero_points = {0, 5};
+
+  test_reference_quantize_per_channel(
+      {3, 4, 2}, // input sizes
+      scales,
+      zero_points,
+      -1, // axis
+      -128, // quant_min
+      127, // quant_max
+      at::kFloat,
+      at::kChar);
+}
+
+TEST(
+    VulkanQuantizePerChannelTest,
+    test_reference_quantize_per_channel_float_to_int8_4D_axis0) {
+  std::vector<float> scales = {0.1, 0.2, 0.00002};
+  std::vector<int> zero_points = {0, 5, -4};
+
+  test_reference_quantize_per_channel(
+      {3, 4, 2, 5}, // input sizes
+      scales,
+      zero_points,
+      0, // axis
+      -128, // quant_min
+      127, // quant_max
+      at::kFloat,
+      at::kChar);
+}
diff --git a/backends/vulkan/vulkan_preprocess.py b/backends/vulkan/vulkan_preprocess.py
index a22afc3f42e..a6d5737dbb8 100644
--- a/backends/vulkan/vulkan_preprocess.py
+++ b/backends/vulkan/vulkan_preprocess.py
@@ -67,7 +67,6 @@
 # pyre-ignore
 def apply_passes(program: ExportedProgram, passes) -> ExportedProgram:
     for p in passes:
-
         if issubclass(type(p), ExportPass) or issubclass(type(p), PassBase):
             new_gm = program.graph_module
             # This is a workaround to allow the memory planning pass to work without
@@ -110,6 +109,9 @@ def parse_compile_spec(compile_specs: List[CompileSpec]) -> Dict[str, Any]:
         if spec.key == "skip_tag_memory_metadata":
             options[spec.key] = bool.from_bytes(spec.value, byteorder="little")
 
+        if spec.key == "downcast_64_bit":
+            options[spec.key] = bool.from_bytes(spec.value, byteorder="little")
+
     # Unhandled options are ignored
     return options
 
@@ -142,6 +144,7 @@ def preprocess(  # noqa: C901
     default_memory_layout = compile_options.get(
         "memory_layout_override", VkMemoryLayout.TENSOR_WIDTH_PACKED
     )
+    downcast_64_bit = compile_options.get("downcast_64_bit", True)
 
     program = unsafe_remove_auto_functionalized_pass(program)
 
@@ -213,7 +216,9 @@ def preprocess(  # noqa: C901
     )
 
     graph_builder = VkGraphBuilder(
-        program, DelegateMappingBuilder(generated_identifiers=True)
+        program,
+        DelegateMappingBuilder(generated_identifiers=True),
+        downcast_64_bit=downcast_64_bit,
     )
     vk_graph = graph_builder.build_graph()
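Editor's note: on the user side, the new option travels as a `CompileSpec` whose value is a little-endian byte, pairing with the `bool.from_bytes(spec.value, byteorder="little")` parsing above. A sketch, assuming the standard `CompileSpec` from `executorch.exir.backend.compile_spec_schema`:

```python
from executorch.exir.backend.compile_spec_schema import CompileSpec


def downcast_64_bit_spec(enabled: bool = True) -> CompileSpec:
    # int.to_bytes(..., "little") pairs with the bool.from_bytes parser.
    return CompileSpec("downcast_64_bit", int(enabled).to_bytes(1, byteorder="little"))


# e.g. keep 64-bit dtypes as-is instead of narrowing them:
compile_specs = [downcast_64_bit_spec(False)]
```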
diff --git a/kernels/quantized/cpu/op_quantize.cpp b/kernels/quantized/cpu/op_quantize.cpp
index d0b7c882f8e..5586f8a77eb 100644
--- a/kernels/quantized/cpu/op_quantize.cpp
+++ b/kernels/quantized/cpu/op_quantize.cpp
@@ -6,7 +6,6 @@
  * LICENSE file in the root directory of this source tree.
  */
 
-#include <executorch/kernels/portable/cpu/util/reduce_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
 #include <algorithm>
 #include <cmath>
@@ -282,55 +281,34 @@ Tensor& quantize_per_channel_out(
 
   check_quantize_per_tensor_args(input, quant_min, quant_max, dtype, out);
 
-  // a list contains all dimensions except axis
-  int64_t dims[kTensorDimensionLimit];
-  for (int64_t i = 0; i < input.dim() - 1; i++) {
-    if (i < axis) {
-      dims[i] = i;
-    } else {
-      dims[i] = i - 1;
-    }
-  }
   const double* scale_data = scale.const_data_ptr<double>();
   const int64_t* zero_point_data = zero_point.const_data_ptr<int64_t>();
 
-  std::optional<executorch::aten::ArrayRef<int64_t>> optional_dim_list{
-      executorch::aten::ArrayRef<int64_t>{dims, size_t(input.dim() - 1)}};
-
-  // Actual quantization logic
-  // input, out are the input and output tensors
-  // channel_ix is the index along the axis dimension. 0 <= channel_ix <
-  // input.size(axis).
-  //   i.e. if the tensor has shape (N,C,H,W), axis being 1, then channel_ix
-  //   will be 0, 1, 2, ... C-1
-  // in_ix is the flat index of the element you are quantizing.
-  //   in other words you are quantizing in_data[in_ix]
+  // High-performance single loop with direct channel calculation
 #define QUANTIZE_IMPL(CTYPE_IN, CTYPE_OUT, out_dtype)                          \
-  case ScalarType::out_dtype:                                                  \
-    for (size_t channel_ix = 0; channel_ix < input.size(axis); ++channel_ix) { \
-      double _scale = scale_data[channel_ix];                                  \
-      int64_t _zero_point = zero_point_data[channel_ix];                       \
-      auto* out_data_ptr = out.mutable_data_ptr<CTYPE_OUT>();                  \
-      const auto* input_data_ptr = input.const_data_ptr<CTYPE_IN>();           \
-      apply_over_dim_list(                                                     \
-          [input_data_ptr,                                                     \
-           out_data_ptr,                                                       \
-           _scale,                                                             \
-           _zero_point,                                                        \
-           quant_min,                                                          \
-           quant_max](size_t in_ix) {                                          \
-            out_data_ptr[in_ix] = quantize_val<CTYPE_OUT, CTYPE_IN>(           \
-                _scale,                                                        \
-                _zero_point,                                                   \
-                input_data_ptr[in_ix],                                         \
-                quant_min,                                                     \
-                quant_max);                                                    \
-          },                                                                   \
-          input,                                                               \
-          optional_dim_list,                                                   \
-          channel_ix);                                                         \
-    }                                                                          \
-    break;
+  case ScalarType::out_dtype: {                                                \
+    auto* out_data_ptr = out.mutable_data_ptr<CTYPE_OUT>();                    \
+    const auto* input_data_ptr = input.const_data_ptr<CTYPE_IN>();             \
+    const int64_t input_numel = input.numel();                                 \
+    const int64_t axis_size = input.size(axis);                                \
+    /* Calculate the stride pattern for efficient channel index calculation */ \
+    int64_t axis_block_size = 1;                                               \
+    for (int64_t i = axis + 1; i < input.dim(); i++) {                         \
+      axis_block_size *= input.size(i);                                        \
+    }                                                                          \
+    /* Single loop over all elements */                                        \
+    for (int64_t i = 0; i < input_numel; i++) {                                \
+      /* Calculate which channel this element belongs to */                    \
+      int64_t channel_idx = (i / axis_block_size) % axis_size;                 \
+      /* Get quantization parameters for this channel */                       \
+      double _scale = scale_data[channel_idx];                                 \
+      int64_t _zero_point = zero_point_data[channel_idx];                      \
+      /* Apply quantization */                                                 \
+      out_data_ptr[i] = quantize_val<CTYPE_OUT, CTYPE_IN>(                     \
+          _scale, _zero_point, input_data_ptr[i], quant_min, quant_max);       \
+    }                                                                          \
+  } break;
 
 #define CALCULATE_FLOAT_TYPE(CTYPE_IN, in_dtype) \
   case ScalarType::in_dtype:                     \
     switch (out.scalar_type()) {                 \
diff --git a/kernels/quantized/cpu/targets.bzl b/kernels/quantized/cpu/targets.bzl
index 3ba9715506a..f29f1f013b7 100644
--- a/kernels/quantized/cpu/targets.bzl
+++ b/kernels/quantized/cpu/targets.bzl
@@ -51,12 +51,6 @@ _QUANT_OPS = (
     ),
     op_target(
         name = "op_quantize",
-        deps = [
-            "//executorch/kernels/portable/cpu/util:reduce_util",
-        ],
-        _aten_mode_deps = [
-            "//executorch/kernels/portable/cpu/util:reduce_util_aten",
-        ],
     ),
 )
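Editor's note: the single-loop kernel derives the channel from the flat index instead of iterating per channel, which is what lets the `reduce_util` dependency above be dropped. For a contiguous row-major tensor, `(i / axis_block_size) % axis_size` recovers exactly the coordinate along `axis`; a quick cross-check against coordinate math:

```python
import itertools

sizes = [2, 3, 4]  # example shape; axis=1 is the channel dimension
axis = 1

axis_size = sizes[axis]
axis_block_size = 1  # product of dims after `axis` (the contiguous stride of axis)
for d in range(axis + 1, len(sizes)):
    axis_block_size *= sizes[d]

# itertools.product yields coordinates in row-major (last-dim-fastest) order,
# matching the flat iteration order of a contiguous tensor.
flat_idx = 0
for coords in itertools.product(*(range(s) for s in sizes)):
    channel_idx = (flat_idx // axis_block_size) % axis_size
    assert channel_idx == coords[axis]
    flat_idx += 1
```

This touches each element exactly once, versus the old approach of one `apply_over_dim_list` sweep per channel.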
diff --git a/kernels/quantized/test/op_quantize_test.cpp b/kernels/quantized/test/op_quantize_test.cpp
index 5cd17223d80..4ac835c24ce 100644
--- a/kernels/quantized/test/op_quantize_test.cpp
+++ b/kernels/quantized/test/op_quantize_test.cpp
@@ -206,3 +206,243 @@ TEST(OpQuantizeOutTest, QuantizePerChannel) {
 
   EXPECT_TENSOR_EQ(out, expected);
 }
+
+TEST(OpQuantizeOutTest, QuantizePerChannelAxis0) {
+  TensorFactory<ScalarType::Float> tf_float;
+  TensorFactory<ScalarType::Double> tf_double;
+  TensorFactory<ScalarType::Long> tf_long;
+
+  Tensor input = tf_float.full({3, 2}, 4);
+  Tensor scale = tf_double.make({3}, {0.5, 1.0, 2.0});
+  Tensor zero_point = tf_long.make({3}, {100, 50, 25});
+  int64_t quant_min = 0;
+  int64_t quant_max = 255;
+
+  TensorFactory<ScalarType::Byte> tfo;
+  Tensor out = tfo.zeros({3, 2});
+  // Channel 0: 4 / 0.5 + 100 = 108
+  // Channel 1: 4 / 1.0 + 50 = 54
+  // Channel 2: 4 / 2.0 + 25 = 27
+  Tensor expected = tfo.make({3, 2}, {108, 108, 54, 54, 27, 27});
+  quantize_per_channel_out(
+      input, scale, zero_point, 0, quant_min, quant_max, ScalarType::Byte, out);
+
+  EXPECT_TENSOR_EQ(out, expected);
+}
+
+TEST(OpQuantizeOutTest, QuantizePerChannel3D) {
+  TensorFactory<ScalarType::Float> tf_float;
+  TensorFactory<ScalarType::Double> tf_double;
+  TensorFactory<ScalarType::Long> tf_long;
+
+  // Test 3D tensor with axis=1 (middle dimension)
+  Tensor input = tf_float.full({2, 3, 4}, 6);
+  Tensor scale = tf_double.make({3}, {0.5, 1.0, 1.5});
+  Tensor zero_point = tf_long.make({3}, {10, 20, 30});
+  int64_t quant_min = -128;
+  int64_t quant_max = 127;
+
+  TensorFactory<ScalarType::Char> tfo;
+  Tensor out = tfo.zeros({2, 3, 4});
+  // Channel 0: 6 / 0.5 + 10 = 22
+  // Channel 1: 6 / 1.0 + 20 = 26
+  // Channel 2: 6 / 1.5 + 30 = 34
+  Tensor expected = tfo.make(
+      {2, 3, 4},
+      {
+          22, 22, 22, 22, // First batch, channel 0
+          26, 26, 26, 26, // First batch, channel 1
+          34, 34, 34, 34, // First batch, channel 2
+          22, 22, 22, 22, // Second batch, channel 0
+          26, 26, 26, 26, // Second batch, channel 1
+          34, 34, 34, 34 // Second batch, channel 2
+      });
+  quantize_per_channel_out(
+      input, scale, zero_point, 1, quant_min, quant_max, ScalarType::Char, out);
+
+  EXPECT_TENSOR_EQ(out, expected);
+}
+
+TEST(OpQuantizeOutTest, QuantizePerChannel4D) {
+  TensorFactory<ScalarType::Float> tf_float;
+  TensorFactory<ScalarType::Double> tf_double;
+  TensorFactory<ScalarType::Long> tf_long;
+
+  // Test 4D tensor with axis=2 (typical conv weight layout: N,C,H,W)
+  Tensor input = tf_float.full({2, 2, 3, 2}, 8);
+  Tensor scale = tf_double.make({3}, {0.25, 0.5, 1.0});
+  Tensor zero_point = tf_long.make({3}, {0, 10, 20});
+  int64_t quant_min = -128;
+  int64_t quant_max = 127;
+
+  TensorFactory<ScalarType::Char> tfo;
+  Tensor out = tfo.zeros({2, 2, 3, 2});
+  // Channel 0: 8 / 0.25 + 0 = 32
+  // Channel 1: 8 / 0.5 + 10 = 26
+  // Channel 2: 8 / 1.0 + 20 = 28
+  std::vector<int8_t> expected_data;
+  for (int n = 0; n < 2; n++) {
+    for (int c = 0; c < 2; c++) {
+      for (int h = 0; h < 3; h++) {
+        for (int w = 0; w < 2; w++) {
+          int8_t val = (h == 0) ? 32 : (h == 1) ? 26 : 28;
+          expected_data.push_back(val);
+        }
+      }
+    }
+  }
+  Tensor expected = tfo.make({2, 2, 3, 2}, expected_data);
+  quantize_per_channel_out(
+      input, scale, zero_point, 2, quant_min, quant_max, ScalarType::Char, out);
+
+  EXPECT_TENSOR_EQ(out, expected);
+}
+
+TEST(OpQuantizeOutTest, QuantizePerChannelNegativeAxis) {
+  TensorFactory<ScalarType::Float> tf_float;
+  TensorFactory<ScalarType::Double> tf_double;
+  TensorFactory<ScalarType::Long> tf_long;
+
+  Tensor input = tf_float.full({2, 3}, 5);
+  Tensor scale = tf_double.make({3}, {0.5, 1.0, 2.0});
+  Tensor zero_point = tf_long.make({3}, {0, 10, 20});
+  int64_t quant_min = 0;
+  int64_t quant_max = 255;
+
+  TensorFactory<ScalarType::Byte> tfo;
+  Tensor out = tfo.zeros({2, 3});
+  // Using axis=-1 should be equivalent to axis=1 for a 2D tensor
+  // Channel 0: 5 / 0.5 + 0 = 10
+  // Channel 1: 5 / 1.0 + 10 = 15
+  // Channel 2: 5 / 2.0 + 20 = 22 (rounded from 22.5)
+  Tensor expected = tfo.make({2, 3}, {10, 15, 22, 10, 15, 22});
+  quantize_per_channel_out(
+      input,
+      scale,
+      zero_point,
+      -1,
+      quant_min,
+      quant_max,
+      ScalarType::Byte,
+      out);
+
+  EXPECT_TENSOR_EQ(out, expected);
+}
+
+TEST(OpQuantizeOutTest, QuantizePerChannelSingleChannel) {
+  TensorFactory<ScalarType::Float> tf_float;
+  TensorFactory<ScalarType::Double> tf_double;
+  TensorFactory<ScalarType::Long> tf_long;
+
+  Tensor input = tf_float.full({3, 1, 4}, 7);
+  Tensor scale = tf_double.make({1}, {0.5});
+  Tensor zero_point = tf_long.make({1}, {128});
+  int64_t quant_min = 0;
+  int64_t quant_max = 255;
+
+  TensorFactory<ScalarType::Byte> tfo;
+  Tensor out = tfo.zeros({3, 1, 4});
+  // Single channel: 7 / 0.5 + 128 = 142
+  Tensor expected = tfo.full({3, 1, 4}, 142);
+  quantize_per_channel_out(
+      input, scale, zero_point, 1, quant_min, quant_max, ScalarType::Byte, out);
+
+  EXPECT_TENSOR_EQ(out, expected);
+}
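Editor's note: as a sanity check, the expected values in these per-channel cases can be reproduced with ATen's own quantizer; for the `QuantizePerChannelAxis0` case above:

```python
import torch

x = torch.full((3, 2), 4.0)
q = torch.quantize_per_channel(
    x,
    torch.tensor([0.5, 1.0, 2.0], dtype=torch.float64),
    torch.tensor([100, 50, 25]),
    0,  # axis
    torch.quint8,
)
# 4/0.5+100=108, 4/1.0+50=54, 4/2.0+25=27, matching the test's expected tensor.
print(q.int_repr())  # tensor([[108, 108], [54, 54], [27, 27]], dtype=torch.uint8)
```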
+
+TEST(OpQuantizeOutTest, QuantizePerChannelDifferentInputTypes) {
+  TensorFactory<ScalarType::Double> tf_double_input;
+  TensorFactory<ScalarType::Double> tf_double;
+  TensorFactory<ScalarType::Long> tf_long;
+
+  Tensor input = tf_double_input.full({2, 2}, 3.14159);
+  Tensor scale = tf_double.make({2}, {0.01, 0.02});
+  Tensor zero_point = tf_long.make({2}, {0, 100});
+  int64_t quant_min = -128;
+  int64_t quant_max = 127;
+
+  TensorFactory<ScalarType::Char> tfo;
+  Tensor out = tfo.zeros({2, 2});
+  // Channel 0: 3.14159 / 0.01 + 0 = 314 -> clamped to 127
+  // Channel 1: 3.14159 / 0.02 + 100 = 257 -> clamped to 127
+  Tensor expected = tfo.make({2, 2}, {127, 127, 127, 127});
+  quantize_per_channel_out(
+      input, scale, zero_point, 1, quant_min, quant_max, ScalarType::Char, out);
+
+  EXPECT_TENSOR_EQ(out, expected);
+}
+
+TEST(OpQuantizeOutTest, QuantizePerChannelDifferentOutputTypes) {
+  TensorFactory<ScalarType::Float> tf_float;
+  TensorFactory<ScalarType::Double> tf_double;
+  TensorFactory<ScalarType::Long> tf_long;
+
+  Tensor input = tf_float.full({2, 2}, 10);
+  Tensor scale = tf_double.make({2}, {1.0, 2.0});
+  Tensor zero_point = tf_long.make({2}, {1000, 2000});
+  int64_t quant_min = -32768;
+  int64_t quant_max = 32767;
+
+  // Test with 16-bit output
+  TensorFactory<ScalarType::Short> tfo;
+  Tensor out = tfo.zeros({2, 2});
+  // Channel 0: 10 / 1.0 + 1000 = 1010
+  // Channel 1: 10 / 2.0 + 2000 = 2005
+  Tensor expected = tfo.make({2, 2}, {1010, 2005, 1010, 2005});
+  quantize_per_channel_out(
+      input,
+      scale,
+      zero_point,
+      1,
+      quant_min,
+      quant_max,
+      ScalarType::Short,
+      out);
+
+  EXPECT_TENSOR_EQ(out, expected);
+}
+
+TEST(OpQuantizeOutTest, QuantizePerChannelMixedValues) {
+  TensorFactory<ScalarType::Float> tf_float;
+  TensorFactory<ScalarType::Double> tf_double;
+  TensorFactory<ScalarType::Long> tf_long;
+
+  // Test with different input values per position
+  Tensor input = tf_float.make({2, 3}, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0});
+  Tensor scale = tf_double.make({3}, {0.5, 1.0, 1.5});
+  Tensor zero_point = tf_long.make({3}, {10, 20, 30});
+  int64_t quant_min = 0;
+  int64_t quant_max = 255;
+
+  TensorFactory<ScalarType::Byte> tfo;
+  Tensor out = tfo.zeros({2, 3});
+  // Row 0: [1.0/0.5+10, 2.0/1.0+20, 3.0/1.5+30] = [12, 22, 32]
+  // Row 1: [4.0/0.5+10, 5.0/1.0+20, 6.0/1.5+30] = [18, 25, 34]
+  Tensor expected = tfo.make({2, 3}, {12, 22, 32, 18, 25, 34});
+  quantize_per_channel_out(
+      input, scale, zero_point, 1, quant_min, quant_max, ScalarType::Byte, out);
+
+  EXPECT_TENSOR_EQ(out, expected);
+}
+
+TEST(OpQuantizeOutTest, QuantizePerChannelClampingBehavior) {
+  TensorFactory<ScalarType::Float> tf_float;
+  TensorFactory<ScalarType::Double> tf_double;
+  TensorFactory<ScalarType::Long> tf_long;
+
+  // Test values that will exceed quant_min/quant_max bounds
+  Tensor input = tf_float.make({1, 3}, {-100.0, 0.0, 100.0});
+  Tensor scale = tf_double.make({3}, {1.0, 1.0, 1.0});
+  Tensor zero_point = tf_long.make({3}, {0, 0, 0});
+  int64_t quant_min = -10;
+  int64_t quant_max = 10;
+
+  TensorFactory<ScalarType::Char> tfo;
+  Tensor out = tfo.zeros({1, 3});
+  // Values: [-100, 0, 100] should be clamped to [-10, 0, 10]
+  Tensor expected = tfo.make({1, 3}, {-10, 0, 10});
+  quantize_per_channel_out(
+      input, scale, zero_point, 1, quant_min, quant_max, ScalarType::Char, out);
+
+  EXPECT_TENSOR_EQ(out, expected);
+}