diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml
index a50f637e250..c15fadd102f 100644
--- a/.github/workflows/pull.yml
+++ b/.github/workflows/pull.yml
@@ -1010,6 +1010,7 @@ jobs:
         ./cmake-out/backends/vulkan/test/custom_ops/q4gsw_linear
         ./cmake-out/backends/vulkan/test/custom_ops/choose_qparams_per_row
         ./cmake-out/backends/vulkan/test/custom_ops/qdq8ta_conv2d_activations
+        ./cmake-out/backends/vulkan/test/custom_ops/q8ta_q8ta_q8to_add
 
         # "Classic" Operator tests
         PYTHON_EXECUTABLE=python bash backends/vulkan/test/scripts/test_op.sh --build
diff --git a/backends/vulkan/_passes/replace_qdq.py b/backends/vulkan/_passes/replace_qdq.py
index 3613c5bf53c..fcfcdfc4c18 100644
--- a/backends/vulkan/_passes/replace_qdq.py
+++ b/backends/vulkan/_passes/replace_qdq.py
@@ -30,6 +30,7 @@ def call(self, graph_module: torch.fx.GraphModule):
             if node.target in [
                 exir_ops.edge.et_vk.conv2d_q8ta_q8csw_q8to.default,
                 exir_ops.edge.et_vk.conv2d_q8ta_q8csw_q8to_dw.default,
+                exir_ops.edge.et_vk.add_q8ta_q8ta_q8to.default,
             ]:
                 # Replace quantize op feeding into conv2d (first argument is the quantized input)
                 quantized_input_node = node.args[0]
diff --git a/backends/vulkan/custom_ops_lib.py b/backends/vulkan/custom_ops_lib.py
index 314c470e5db..6e5aa926d37 100644
--- a/backends/vulkan/custom_ops_lib.py
+++ b/backends/vulkan/custom_ops_lib.py
@@ -572,3 +572,45 @@ def dequantize_q8to_from_conv2d_impl(
 lib.define(f"{name}(Tensor input, float scale, int zero_point) -> Tensor")
 lib.impl(name, dequantize_q8to_from_conv2d_impl, "CompositeExplicitAutograd")
 dequantize_q8to_from_conv2d_op = getattr(getattr(torch.ops, namespace), name)
+
+########################
+## add_q8ta_q8ta_q8to ##
+########################
+
+
+def add_q8ta_q8ta_q8to_impl(
+    input_a: torch.Tensor,
+    input_b: torch.Tensor,
+    input_a_scale: float,
+    input_a_zero_point: int,
+    input_b_scale: float,
+    input_b_zero_point: int,
+    output_scale: float,
+    output_zero_point: int,
+    alpha: float,
+):
+    # Dequantize inputs to float
+    dequant_a = torch.ops.quantized_decomposed.dequantize_per_tensor(
+        input_a, input_a_scale, input_a_zero_point, -128, 127, input_a.dtype
+    )
+    dequant_b = torch.ops.quantized_decomposed.dequantize_per_tensor(
+        input_b, input_b_scale, input_b_zero_point, -128, 127, input_b.dtype
+    )
+
+    # Perform addition with alpha scaling
+    result = dequant_a + alpha * dequant_b
+
+    # Quantize the result back to int8
+    quantized_result = torch.ops.quantized_decomposed.quantize_per_tensor(
+        result, output_scale, output_zero_point, -128, 127, torch.int8
+    )
+
+    return quantized_result
+
+
+name = "add_q8ta_q8ta_q8to"
+lib.define(
+    f"{name}(Tensor input_a, Tensor input_b, float input_a_scale, int input_a_zero_point, float input_b_scale, int input_b_zero_point, float output_scale, int output_zero_point, float alpha) -> Tensor"
+)
+lib.impl(name, add_q8ta_q8ta_q8to_impl, "CompositeExplicitAutograd")
+add_q8ta_q8ta_q8to_op = getattr(getattr(torch.ops, namespace), name)
diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py
index 8d67a5275d7..a92b3b11f6f 100644
--- a/backends/vulkan/op_registry.py
+++ b/backends/vulkan/op_registry.py
@@ -523,6 +523,19 @@ def register_quantized_conv_op():
     )
 
 
+@update_features(
+    [
+        exir_ops.edge.et_vk.add_q8ta_q8ta_q8to.default,
+    ]
+)
+def register_quantized_binary_op():
+    return OpFeatures(
+        inputs_storage=utils.PACKED_INT8_4W4C_BUFFER,
+        supports_resize=False,
+        supports_prepacking=True,
+    )
+
+
 @update_features(
     [
         exir_ops.edge.et_vk.quantize_q8ta_for_conv2d.default,
diff --git a/backends/vulkan/patterns/TARGETS b/backends/vulkan/patterns/TARGETS
index 791edf58984..285efe2b933 100644
--- a/backends/vulkan/patterns/TARGETS
+++ b/backends/vulkan/patterns/TARGETS
@@ -11,6 +11,7 @@ runtime.python_library(
         "rope.py",
         "quantized_linear.py",
         "quantized_convolution.py",
+        "quantized_binary.py",
     ],
     visibility = [
         "//executorch/backends/...",
diff --git a/backends/vulkan/patterns/__init__.py b/backends/vulkan/patterns/__init__.py
index 8ffad98b3c3..e23dfc7629c 100644
--- a/backends/vulkan/patterns/__init__.py
+++ b/backends/vulkan/patterns/__init__.py
@@ -6,6 +6,8 @@
 
 from typing import List
 
+import executorch.backends.vulkan.patterns.quantized_binary  # noqa
+
 import executorch.backends.vulkan.patterns.quantized_convolution  # noqa
 
 import executorch.backends.vulkan.patterns.quantized_linear  # noqa
diff --git a/backends/vulkan/patterns/quantized_binary.py b/backends/vulkan/patterns/quantized_binary.py
new file mode 100644
index 00000000000..da4985b931d
--- /dev/null
+++ b/backends/vulkan/patterns/quantized_binary.py
@@ -0,0 +1,161 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Optional
+
+import executorch.backends.vulkan.utils as utils
+
+import torch
+
+from executorch.backends.vulkan.patterns.pattern_registry import (
+    PatternMatch,
+    register_pattern_detector,
+    register_pattern_replacement,
+)
+
+from executorch.exir import ExportedProgram
+from executorch.exir.dialects._ops import ops as exir_ops
+
+
+class QuantizedBinaryMatch(PatternMatch):
+    def __init__(self, binary_node: torch.fx.Node) -> None:
+        self.anchor_node = binary_node
+        self.match_found = False
+        self.all_nodes = [self.anchor_node]
+
+        # Extract alpha parameter if it exists (for add operations)
+        self.alpha = 1.0
+        if len(binary_node.args) > 2 and binary_node.args[2] is not None:
+            # Alpha is typically a scalar value
+            if isinstance(binary_node.args[2], (int, float)):
+                self.alpha = binary_node.args[2]
+
+        # Identify input nodes - both should be dequantize nodes for static quantization
+        if len(binary_node.args) < 2:
+            return
+
+        input_a_node = binary_node.args[0]
+        assert isinstance(input_a_node, torch.fx.Node)
+        input_b_node = binary_node.args[1]
+        assert isinstance(input_b_node, torch.fx.Node)
+
+        # Both arguments must be dequant nodes for static quantization
+        if not utils.is_dequant_node(input_a_node) or not utils.is_dequant_node(
+            input_b_node
+        ):
+            return
+
+        self.dequantize_input_a_node = input_a_node
+        self.dequantize_input_b_node = input_b_node
+
+        # Extract quantization parameters for input A
+        self.quantize_input_a_node = self.dequantize_input_a_node.args[0]
+        self.input_a_scales_node = self.dequantize_input_a_node.args[1]
+        self.input_a_zeros_node = self.dequantize_input_a_node.args[2]
+
+        # Extract quantization parameters for input B
+        self.quantize_input_b_node = self.dequantize_input_b_node.args[0]
+        self.input_b_scales_node = self.dequantize_input_b_node.args[1]
+        self.input_b_zeros_node = self.dequantize_input_b_node.args[2]
+
+        self.all_nodes.extend(
+            [self.dequantize_input_a_node, self.dequantize_input_b_node]
+        )
+
+        # Identify output node
+        self.output_node = self.anchor_node
+
+        # The binary operation output must have only one user; it will be either
+        # a relu node or a quantize node.
+        if len(self.output_node.users) != 1:
+            return
+
+        cur_node = list(self.output_node.users)[0]
+        self.relu_node = None
+        if cur_node.target == exir_ops.edge.aten.relu.default:
+            self.relu_node = cur_node
+            self.all_nodes.append(self.relu_node)
+            # If there's a relu, get its user (should be the quantize node)
+            if len(cur_node.users) != 1:
+                return
+            cur_node = list(cur_node.users)[0]
+
+        if not utils.is_quant_node(cur_node):
+            return
+
+        self.quantize_output_node = cur_node
+        self.output_scales_node = self.quantize_output_node.args[1]
+        self.output_zeros_node = self.quantize_output_node.args[2]
+
+        self.all_nodes.append(self.quantize_output_node)
+
+        self.match_found = True
+
+
+# Define the binary operation anchor nodes that we support
+binary_anchor_nodes = {
+    exir_ops.edge.aten.add.Tensor,
+    exir_ops.edge.aten.add_.Tensor,
+}
+
+
+@register_pattern_detector("quantized_binary")
+def find_quantized_binary_patterns(
+    node: torch.fx.Node,
+) -> Optional[QuantizedBinaryMatch]:
+    if node.target not in binary_anchor_nodes:
+        return None
+
+    matched_pattern = QuantizedBinaryMatch(node)
+    if matched_pattern.match_found:
+        return matched_pattern
+
+    return None
+
+
+##
+## Pattern Replacement
+##
+
+
+@register_pattern_replacement("quantized_binary")
+def make_add_q8ta_q8ta_q8to_custom_op(
+    ep: ExportedProgram,
+    graph_module: torch.fx.GraphModule,
+    match: QuantizedBinaryMatch,
+):
+    # Determine the operation type based on the anchor node
+    op_target = None
+    if match.anchor_node.target in {
+        exir_ops.edge.aten.add.Tensor,
+        exir_ops.edge.aten.add_.Tensor,
+    }:
+        op_target = exir_ops.edge.et_vk.add_q8ta_q8ta_q8to.default
+    else:
+        # For future binary operations, add more mappings here
+        raise NotImplementedError(
+            f"Unsupported binary operation: {match.anchor_node.target}"
+        )
+
+    with graph_module.graph.inserting_before(match.output_node):
+        qbinary_node = graph_module.graph.create_node(
+            "call_function",
+            op_target,
+            args=(
+                match.quantize_input_a_node,
+                match.quantize_input_b_node,
+                match.input_a_scales_node,
+                match.input_a_zeros_node,
+                match.input_b_scales_node,
+                match.input_b_zeros_node,
+                match.output_scales_node,
+                match.output_zeros_node,
+                match.alpha,  # Alpha parameter for scaling
+            ),
+        )
+
+    qbinary_node.meta["val"] = match.output_node.meta["val"]
+    match.quantize_output_node.replace_all_uses_with(qbinary_node)
diff --git a/backends/vulkan/runtime/graph/ops/glsl/binary_q8ta_q8ta_q8to.glsl b/backends/vulkan/runtime/graph/ops/glsl/binary_q8ta_q8ta_q8to.glsl
new file mode 100644
index 00000000000..8b69642d2e9
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/glsl/binary_q8ta_q8ta_q8to.glsl
@@ -0,0 +1,78 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#version 450 core
+
+#define PRECISION ${PRECISION}
+
+#define NAME ${VARIANT_NAME}
+
+#define VEC4_T ${texel_load_type(DTYPE, "buffer")}
+#define T ${texel_load_component_type(DTYPE, "buffer")}
+
+$if IO_STORAGE == "buffer":
+  #define PACKED_INT8_OUTPUT_BUFFER
+  #define PACKED_INT8_INPUT_BUFFER
+
+#define op(X, Y) ${OPERATOR}
+
+${define_required_extensions(DTYPE)}
+
+layout(std430) buffer;
+
+#extension GL_EXT_debug_printf : enable
+#define DEBUG_MODE
+#include "indexing.glslh"
+#include "common.glslh"
+
+${layout_declare_tensor(B, "w", "t_packed_int8_out", "int", IO_STORAGE, is_scalar_array=False)}
+${layout_declare_tensor(B, "r", "t_packed_int8_in_a", "int", IO_STORAGE, is_scalar_array=False)}
+${layout_declare_tensor(B, "r", "t_packed_int8_in_b", "int", IO_STORAGE, is_scalar_array=False)}
+
+${layout_declare_ubo(B, "ivec4", "out_sizes")}
+
+layout(push_constant) uniform restrict Block {
+  float input_a_scale;
+  int input_a_zp;
+  float input_b_scale;
+  int input_b_zp;
+  float output_inv_scale;
+  int output_zp;
+};
+
+layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
+
+void main() {
+  const int tid = int(gl_GlobalInvocationID.x);
+
+  const int W4 = div_up_4(out_sizes.x);
+  const int H = out_sizes.y;
+  const int C4 = div_up_4(out_sizes.z);
+  const int N = out_sizes.w;
+
+  if (tid >= W4 * H * C4 * N) {
+    return;
+  }
+
+  const ivec4 in_block_1 = t_packed_int8_in_a[tid];
+  const ivec4 in_block_2 = t_packed_int8_in_b[tid];
+
+  ivec4 out_block = ivec4(pack_into_int32(ivec4(output_zp)));
+
+  for (int row = 0; row < 4; row++) {
+    vec4 in_texel_1 = unpack_and_dequantize(
+        in_block_1[row], input_a_scale, input_a_zp);
+    vec4 in_texel_2 = unpack_and_dequantize(
+        in_block_2[row], input_b_scale, input_b_zp);
+
+    vec4 out_texel = op(in_texel_1, in_texel_2);
+    out_block[row] = quantize_and_pack(out_texel, output_inv_scale, output_zp);
+  }
+
+  t_packed_int8_out[tid] = out_block;
+}
diff --git a/backends/vulkan/runtime/graph/ops/glsl/binary_q8ta_q8ta_q8to.yaml b/backends/vulkan/runtime/graph/ops/glsl/binary_q8ta_q8ta_q8to.yaml
new file mode 100644
index 00000000000..e19ed8839eb
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/glsl/binary_q8ta_q8ta_q8to.yaml
@@ -0,0 +1,19 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
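+#
+# Codegen config for the binary_q8ta_q8ta_q8to shader template: OPERATOR is
+# substituted into op(X, Y), and only the buffer storage variant is generated.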
+
+binary_q8ta_q8ta_q8to:
+  parameter_names_with_default_values:
+    OPERATOR: X + Y
+    NDIM: 3
+    DTYPE: float
+    PACKING: C_packed
+    IO_STORAGE: buffer
+  generate_variant_forall:
+    IO_STORAGE:
+      - VALUE: buffer
+  shader_variants:
+    - NAME: add_q8ta_q8ta_q8to
+      OPERATOR: X + Y
diff --git a/backends/vulkan/runtime/graph/ops/glsl/common.glslh b/backends/vulkan/runtime/graph/ops/glsl/common.glslh
index 95cdf70679b..eb0ee02c2b4 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/common.glslh
+++ b/backends/vulkan/runtime/graph/ops/glsl/common.glslh
@@ -72,6 +72,20 @@ int pack_into_int32(const ivec4 quant_vals) {
   return packed;
 }
 
+vec4 unpack_and_dequantize(
+    const int packed_int8_vals,
+    const float scale,
+    const int zp) {
+  ivec4 unpacked = unpack_int8x4(packed_int8_vals);
+  return vec4(unpacked - zp) * scale;
+}
+
+int quantize_and_pack(const vec4 vals, const float inv_scale, const int zp) {
+  ivec4 quantized = ivec4(round(vals * inv_scale) + zp);
+  quantized = clamp(quantized, -128, 127);
+  return pack_into_int32(quantized);
+}
+
 #ifdef DEBUG_MODE
 
 #extension GL_EXT_debug_printf : require
diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedBinary.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedBinary.cpp
new file mode 100644
index 00000000000..4b359f12700
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedBinary.cpp
@@ -0,0 +1,210 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
+
+#include <executorch/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.h>
+#include <executorch/backends/vulkan/runtime/graph/ops/impl/QuantizedConvolution.h>
+#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
+#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
+
+namespace vkcompute {
+
+//
+// Shader dispatch utilities
+//
+
+utils::uvec3 pick_q8ta_q8ta_q8to_binary_global_wg_size(
+    ComputeGraph* graph,
+    const vkapi::ShaderInfo& shader,
+    const std::vector<ArgGroup>& args,
+    const std::vector<ValueRef>& resize_args) {
+  const ValueRef packed_int8_output = args.at(0).refs.at(0);
+
+  const uint32_t W = graph->size_at<uint32_t>(-1, packed_int8_output);
+  const uint32_t H = graph->size_at<uint32_t>(-2, packed_int8_output);
+  const uint32_t C = graph->size_at<uint32_t>(-3, packed_int8_output);
+
+  const uint32_t W4 = utils::div_up_4(W);
+  const uint32_t C4 = utils::div_up_4(C);
+
+  return {W4 * H * C4, 1, 1};
+}
+
+//
+// Dispatch nodes
+//
+
+void add_q8ta_q8ta_q8to_binary_node(
+    ComputeGraph& graph,
+    const ValueRef packed_int8_input_a,
+    const ValueRef packed_int8_input_b,
+    const ValueRef input_a_scale,
+    const ValueRef input_a_zp,
+    const ValueRef input_b_scale,
+    const ValueRef input_b_zp,
+    const ValueRef output_scale,
+    const ValueRef output_zp,
+    const ValueRef alpha,
+    const ValueRef packed_int8_output,
+    const std::string& op_name) {
+  float input_a_scale_val = graph.extract_scalar<float>(input_a_scale);
+  int32_t input_a_zp_val = graph.extract_scalar<int32_t>(input_a_zp);
+  float input_b_scale_val = graph.extract_scalar<float>(input_b_scale);
+  int32_t input_b_zp_val = graph.extract_scalar<int32_t>(input_b_zp);
+
+  float output_inv_scale_val = 1.0f / graph.extract_scalar<float>(output_scale);
+  int32_t output_zp_val = graph.extract_scalar<int32_t>(output_zp);
+
+  float alpha_val = 1.0f;
+  // String is checked since some ops pass in an unused string argument in
+  // place of alpha
+  if (is_valid(alpha) && !graph.val_is_string(alpha)) {
+    alpha_val = graph.extract_scalar<float>(alpha);
+  }
+
+  std::string kernel_name = op_name + "_q8ta_q8ta_q8to";
+  add_storage_type_suffix(
+      kernel_name, graph.storage_type_of(packed_int8_output));
+
+  vkapi::ParamsBindList param_buffers =
+      {graph.sizes_ubo(packed_int8_output)};
+
+  std::vector<PushConstantDataInfo> push_constants = {
+      PushConstantDataInfo(&input_a_scale_val, sizeof(input_a_scale_val)),
+      PushConstantDataInfo(&input_a_zp_val, sizeof(input_a_zp_val)),
+      PushConstantDataInfo(&input_b_scale_val, sizeof(input_b_scale_val)),
+      PushConstantDataInfo(&input_b_zp_val, sizeof(input_b_zp_val)),
+      PushConstantDataInfo(&output_inv_scale_val, sizeof(output_inv_scale_val)),
+      PushConstantDataInfo(&output_zp_val, sizeof(output_zp_val)),
+      PushConstantDataInfo(&alpha_val, sizeof(alpha_val)),
+  };
+
+  graph.execute_nodes().emplace_back(new DynamicDispatchNode(
+      graph,
+      VK_KERNEL_FROM_STR(kernel_name),
+      pick_q8ta_q8ta_q8to_binary_global_wg_size,
+      default_pick_local_wg_size,
+      // Inputs and Outputs
+      {{packed_int8_output, vkapi::kWrite},
+       {{packed_int8_input_a, packed_int8_input_b}, vkapi::kRead}},
+      // Shader params buffers
+      param_buffers,
+      // Push Constants
+      push_constants,
+      // Specialization Constants
+      {},
+      // Resize args
+      {},
+      // Resizing Logic
+      nullptr));
+}
+
+//
+// High level operator impl
+//
+
+void add_q8ta_q8ta_q8to(
+    ComputeGraph& graph,
+    const std::vector<ValueRef>& args) {
+  int32_t idx = 0;
+  const ValueRef packed_int8_input_a = args.at(idx++);
+  const ValueRef packed_int8_input_b = args.at(idx++);
+  const ValueRef input_a_scale = args.at(idx++);
+  const ValueRef input_a_zp = args.at(idx++);
+  const ValueRef input_b_scale = args.at(idx++);
+  const ValueRef input_b_zp = args.at(idx++);
+  const ValueRef output_scale = args.at(idx++);
+  const ValueRef output_zp = args.at(idx++);
+  const ValueRef alpha = args.at(idx++);
+  const ValueRef packed_int8_output = args.at(idx++);
+
+  add_q8ta_q8ta_q8to_binary_node(
+      graph,
+      packed_int8_input_a,
+      packed_int8_input_b,
+      input_a_scale,
+      input_a_zp,
+      input_b_scale,
+      input_b_zp,
+      output_scale,
+      output_zp,
+      alpha,
+      packed_int8_output,
+      "add");
+}
+
+//
+// Test operators
+//
+
+void add_q8ta_q8ta_q8to_test(
+    ComputeGraph& graph,
+    const std::vector<ValueRef>& args) {
+  int32_t idx = 0;
+  const ValueRef fp_input_a = args.at(idx++);
+  const ValueRef fp_input_b = args.at(idx++);
+  const ValueRef input_a_scale = args.at(idx++);
+  const ValueRef input_a_zp = args.at(idx++);
+  const ValueRef input_b_scale = args.at(idx++);
+  const ValueRef input_b_zp = args.at(idx++);
+  const ValueRef output_scale = args.at(idx++);
+  const ValueRef output_zp = args.at(idx++);
+  const ValueRef alpha = args.at(idx++);
+  const ValueRef fp_output = args.at(idx++);
+
+  TmpTensor packed_int8_input_a(
+      &graph,
+      graph.sizes_of(fp_input_a),
+      vkapi::kInt8x4,
+      utils::kBuffer,
+      utils::kPackedInt8_4W4C);
+
+  TmpTensor packed_int8_input_b(
+      &graph,
+      graph.sizes_of(fp_input_b),
+      vkapi::kInt8x4,
+      utils::kBuffer,
+      utils::kPackedInt8_4W4C);
+
+  TmpTensor packed_int8_output(
+      &graph,
+      graph.sizes_of(fp_output),
+      vkapi::kInt8x4,
+      utils::kBuffer,
+      utils::kPackedInt8_4W4C);
+
+  add_quantize_and_pack_q8ta_conv2d_input_node(
+      graph, fp_input_a, input_a_scale, input_a_zp, packed_int8_input_a);
+
+  add_quantize_and_pack_q8ta_conv2d_input_node(
+      graph, fp_input_b, input_b_scale, input_b_zp, packed_int8_input_b);
+
+  std::vector<ValueRef> add_args = {
+      packed_int8_input_a,
+      packed_int8_input_b,
+      input_a_scale,
+      input_a_zp,
+      input_b_scale,
+      input_b_zp,
+      output_scale,
+      output_zp,
+      alpha,
+      packed_int8_output};
+
+  add_q8ta_q8ta_q8to(graph, add_args);
+
+  add_unpack_and_dequantize_q8ta_conv2d_output_node(
+      graph, packed_int8_output, output_scale, output_zp, fp_output);
+}
+
+REGISTER_OPERATORS {
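+  // The .default variant consumes prepacked int8 tensors directly; the .test
+  // variant wraps it with quantize/pack and unpack/dequantize nodes so the
+  // prototyping harness can feed and read back floating point tensors.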
+  VK_REGISTER_OP(et_vk.add_q8ta_q8ta_q8to.default, add_q8ta_q8ta_q8to);
+  VK_REGISTER_OP(et_vk.add_q8ta_q8ta_q8to.test, add_q8ta_q8ta_q8to_test);
+}
+
+} // namespace vkcompute
diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedConvolution.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedConvolution.cpp
index 75bbb3892df..775e4534cfb 100644
--- a/backends/vulkan/runtime/graph/ops/impl/QuantizedConvolution.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedConvolution.cpp
@@ -9,6 +9,7 @@
 #include
 #include
+#include <executorch/backends/vulkan/runtime/graph/ops/impl/QuantizedConvolution.h>
 #include
 #include
 #include
diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedConvolution.h b/backends/vulkan/runtime/graph/ops/impl/QuantizedConvolution.h
new file mode 100644
index 00000000000..33474cee47b
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedConvolution.h
@@ -0,0 +1,42 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
+
+namespace vkcompute {
+
+//
+// Quantize and dequantize functions for conv2d that can be reused by other
+// operations
+//
+
+/**
+ * Add a dispatch node to quantize a floating-point input tensor to a packed
+ * int8 tensor for use in quantized operations.
+ */
+void add_quantize_and_pack_q8ta_conv2d_input_node(
+    ComputeGraph& graph,
+    const ValueRef fp_input,
+    const ValueRef input_scale,
+    const ValueRef input_zp,
+    const ValueRef packed_int8_input);
+
+/**
+ * Add a dispatch node to unpack and dequantize a packed int8 output tensor
+ * back to a floating-point tensor.
+ */
+void add_unpack_and_dequantize_q8ta_conv2d_output_node(
+    ComputeGraph& graph,
+    const ValueRef packed_int8_output,
+    const ValueRef output_scale,
+    const ValueRef output_zp,
+    const ValueRef fp_output);
+
+} // namespace vkcompute
diff --git a/backends/vulkan/test/custom_ops/CMakeLists.txt b/backends/vulkan/test/custom_ops/CMakeLists.txt
index 348eeded962..fc1d33391d4 100644
--- a/backends/vulkan/test/custom_ops/CMakeLists.txt
+++ b/backends/vulkan/test/custom_ops/CMakeLists.txt
@@ -98,4 +98,5 @@ if(TARGET vulkan_backend)
   add_operator_prototype(qdq8ta_conv2d_activations)
   add_operator_prototype(q8ta_q8csw_q8to_conv2d)
   add_operator_prototype(q8ta_q8csw_q8to_conv2d_dw)
+  add_operator_prototype(q8ta_q8ta_q8to_add)
 endif()
diff --git a/backends/vulkan/test/custom_ops/q8ta_q8ta_q8to_add.cpp b/backends/vulkan/test/custom_ops/q8ta_q8ta_q8to_add.cpp
new file mode 100644
index 00000000000..5799bc194c9
--- /dev/null
+++ b/backends/vulkan/test/custom_ops/q8ta_q8ta_q8to_add.cpp
@@ -0,0 +1,265 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <algorithm>
+#include <cmath>
+#include <cstdint>
+#include <iostream>
+#include <stdexcept>
+
+#include "utils.h"
+
+using namespace executorch::vulkan::prototyping;
+
+// Utility function to create a test case for quantized add operation
+TestCase create_quantized_add_test_case(
+    const std::vector<int64_t>& sizes,
+    utils::StorageType storage_type,
+    vkapi::ScalarType input_dtype) {
+  TestCase test_case;
+
+  // Create a descriptive name for the test case
+  std::string size_str = "";
+  for (size_t i = 0; i < sizes.size(); ++i) {
+    size_str += std::to_string(sizes[i]);
+    if (i < sizes.size() - 1)
+      size_str += "x";
+  }
+
+  std::string storage_str =
+      (storage_type == utils::kTexture3D) ? "Texture3D" : "Buffer";
+  std::string dtype_str = (input_dtype == vkapi::kFloat) ? "Float" : "Half";
+
+  std::string test_name =
+      "QuantizedAdd_" + size_str + "_" + storage_str + "_" + dtype_str;
+  test_case.set_name(test_name);
+
+  // Set the operator name for the test case
+  test_case.set_operator_name("et_vk.add_q8ta_q8ta_q8to.test");
+
+  // Input tensor A (float/half)
+  ValueSpec input_a(
+      sizes,
+      input_dtype,
+      storage_type,
+      utils::kChannelsPacked,
+      DataGenType::RANDOM);
+
+  // Input tensor B (float/half)
+  ValueSpec input_b(
+      sizes,
+      input_dtype,
+      storage_type,
+      utils::kChannelsPacked,
+      DataGenType::RANDOM);
+
+  // Quantization parameters for input A
+  float input_a_scale_val = 0.007843; // 2/255 approximately
+  ValueSpec input_a_scale(input_a_scale_val);
+
+  int32_t input_a_zero_point_val = 3;
+  ValueSpec input_a_zero_point(input_a_zero_point_val);
+
+  // Quantization parameters for input B
+  float input_b_scale_val = 0.009412; // 2.4/255 approximately
+  ValueSpec input_b_scale(input_b_scale_val);
+
+  int32_t input_b_zero_point_val = -2;
+  ValueSpec input_b_zero_point(input_b_zero_point_val);
+
+  // Output quantization parameters
+  float output_scale_val = 0.015686; // 4/255 approximately
+  ValueSpec output_scale(output_scale_val);
+
+  int32_t output_zero_point_val = 1;
+  ValueSpec output_zero_point(output_zero_point_val);
+
+  // Alpha parameter
+  float alpha_val = 1.0f;
+  ValueSpec alpha(alpha_val);
+
+  // Output tensor (float/half)
+  ValueSpec output(
+      sizes,
+      input_dtype,
+      storage_type,
+      utils::kChannelsPacked,
+      DataGenType::ZEROS);
+
+  // Add all specs to test case for q8ta_q8ta_q8to add operation
+  test_case.add_input_spec(input_a);
+  test_case.add_input_spec(input_b);
+  test_case.add_input_spec(input_a_scale);
+  test_case.add_input_spec(input_a_zero_point);
+  test_case.add_input_spec(input_b_scale);
+  test_case.add_input_spec(input_b_zero_point);
+  test_case.add_input_spec(output_scale);
+  test_case.add_input_spec(output_zero_point);
+  test_case.add_input_spec(alpha);
+
+  test_case.add_output_spec(output);
+
+  test_case.set_abs_tolerance(output_scale_val + 1e-4f);
+
+  return test_case;
+}
+
+// Generate test cases for quantized add operation
+std::vector<TestCase> generate_quantized_add_test_cases() {
+  std::vector<TestCase> test_cases;
+
+  // Define different input size configurations
+  std::vector<std::vector<int64_t>> size_configs = {
+      {3, 32, 32}, // Small square
+      {8, 64, 64}, // Medium square
+      {16, 16, 16}, // 3D cube
+      {8, 32, 16}, // 3D rectangular
+      {7, 7, 13}, // Irregular sizes
+  };
+
+  // Storage types to test
+  std::vector<utils::StorageType> storage_types = {utils::kTexture3D};
+
+  // Data types to test
+  std::vector<vkapi::ScalarType> data_types = {vkapi::kFloat};
+
+  // Generate test cases for each combination
+  for (const auto& sizes : size_configs) {
+    for (const auto& storage_type : storage_types) {
+      for (const auto& data_type : data_types) {
+        test_cases.push_back(
+            create_quantized_add_test_case(sizes, storage_type, data_type));
+      }
+    }
+  }
+
+  return test_cases;
+}
+
+// Reference implementation for quantized add operation
+void add_q8ta_q8ta_q8to_reference_impl(TestCase& test_case) {
+  // Extract input specifications
+  int32_t idx = 0;
+  const ValueSpec& input_a_spec = test_case.inputs()[idx++];
+  const ValueSpec& input_b_spec = test_case.inputs()[idx++];
+  const ValueSpec& input_a_scale_spec = test_case.inputs()[idx++];
+  const ValueSpec& input_a_zero_point_spec = test_case.inputs()[idx++];
+  const ValueSpec& input_b_scale_spec = test_case.inputs()[idx++];
+  const ValueSpec& input_b_zero_point_spec = test_case.inputs()[idx++];
+  const ValueSpec& output_scale_spec = test_case.inputs()[idx++];
+  const ValueSpec& output_zero_point_spec = test_case.inputs()[idx++];
+  const ValueSpec& alpha_spec = test_case.inputs()[idx++];
+
+  // Extract output specification (mutable reference)
+  ValueSpec& output_spec = test_case.outputs()[0];
+
+  // Get tensor dimensions
+  auto input_sizes = input_a_spec.get_tensor_sizes();
+  int64_t num_elements = input_a_spec.numel();
+
+  if (input_a_spec.dtype != vkapi::kFloat) {
+    throw std::invalid_argument("Unsupported dtype");
+  }
+
+  // Get raw data pointers
+  auto& input_a_data = input_a_spec.get_float_data();
+  auto& input_b_data = input_b_spec.get_float_data();
+
+  const float input_a_scale = input_a_scale_spec.get_float_value();
+  const int32_t input_a_zero_point = input_a_zero_point_spec.get_int_value();
+  const float input_b_scale = input_b_scale_spec.get_float_value();
+  const int32_t input_b_zero_point = input_b_zero_point_spec.get_int_value();
+  const float output_scale = output_scale_spec.get_float_value();
+  const int32_t output_zero_point = output_zero_point_spec.get_int_value();
+  const float alpha = alpha_spec.get_float_value();
+
+  auto& ref_data = output_spec.get_ref_float_data();
+  ref_data.resize(num_elements);
+
+  // Perform quantized add operation
+  for (int64_t i = 0; i < num_elements; ++i) {
+    // Quantize input A to int8
+    float quant_a_f =
+        std::round(input_a_data[i] / input_a_scale) + input_a_zero_point;
+    quant_a_f = std::min(std::max(quant_a_f, -128.0f), 127.0f);
+    int8_t quantized_a = static_cast<int8_t>(quant_a_f);
+
+    // Quantize input B to int8
+    float quant_b_f =
+        std::round(input_b_data[i] / input_b_scale) + input_b_zero_point;
+    quant_b_f = std::min(std::max(quant_b_f, -128.0f), 127.0f);
+    int8_t quantized_b = static_cast<int8_t>(quant_b_f);
+
+    // Dequantize both inputs back to float for the addition
+    float dequant_a =
+        (static_cast<float>(quantized_a) - input_a_zero_point) * input_a_scale;
+    float dequant_b =
+        (static_cast<float>(quantized_b) - input_b_zero_point) * input_b_scale;
+
+    // Perform addition in float space with alpha
+    float float_result = dequant_a + alpha * dequant_b;
+
+    // Quantize the result to int8
+    float quant_output_f =
+        std::round(float_result / output_scale) + output_zero_point;
+    quant_output_f = std::min(std::max(quant_output_f, -128.0f), 127.0f);
+    int8_t quantized_output = static_cast<int8_t>(quant_output_f);
+
+    // Dequantize back to float for comparison
+    float dequant_output =
+        (static_cast<float>(quantized_output) - output_zero_point) *
+        output_scale;
+
+    ref_data[i] = dequant_output;
+  }
+}
+
+void reference_impl(TestCase& test_case) {
+  add_q8ta_q8ta_q8to_reference_impl(test_case);
+}
+
+// Custom FLOP calculator for quantized add operation
+int64_t quantized_add_flop_calculator(const TestCase& test_case) {
+  // Calculate total elements from the first input tensor
+  int64_t total_elements = 1;
+  if (!test_case.empty() && test_case.num_inputs() > 0 &&
+      test_case.inputs()[0].is_tensor()) {
+    const auto& sizes = test_case.inputs()[0].get_tensor_sizes();
+    for (int64_t size : sizes) {
+      total_elements *= size;
+    }
+  }
+
+  // Quantized add operation includes:
+  // - 2 quantizations (float to int8)
+  // - 2 dequantizations (int8 to float)
+  // - 1 addition
+  // For simplicity, we count this as 1 FLOP per element (the addition)
+  return total_elements;
+}
+
+int main(int argc, char* argv[]) {
+  set_debugging(false);
+  set_print_output(false);
+  set_print_latencies(false);
+  set_use_gpu_timestamps(true);
+
+  print_performance_header();
+  std::cout << "Quantized Add Operation (q8ta_q8ta_q8to) Prototyping Framework"
+            << std::endl;
+  print_separator();
+
+  ReferenceComputeFunc ref_fn = reference_impl;
+
+  // Execute test cases using the new framework with custom FLOP calculator
+  auto results = execute_test_cases(
+      generate_quantized_add_test_cases,
+      quantized_add_flop_calculator,
+      "QuantizedAddQ8taQ8taQ8to",
+      0,
+      1,
+      ref_fn);
+
+  return 0;
+}
diff --git a/backends/vulkan/test/custom_ops/targets.bzl b/backends/vulkan/test/custom_ops/targets.bzl
index 959e013981c..4ef1cdd7fed 100644
--- a/backends/vulkan/test/custom_ops/targets.bzl
+++ b/backends/vulkan/test/custom_ops/targets.bzl
@@ -102,3 +102,4 @@ def define_common_targets(is_fbcode = False):
     define_custom_op_test_binary("qdq8ta_conv2d_activations")
     define_custom_op_test_binary("q8ta_q8csw_q8to_conv2d")
     define_custom_op_test_binary("q8ta_q8csw_q8to_conv2d_dw")
+    define_custom_op_test_binary("q8ta_q8ta_q8to_add")