diff --git a/backends/vulkan/_passes/replace_qdq.py b/backends/vulkan/_passes/replace_qdq.py index fcfcdfc4c18..2c5331eb213 100644 --- a/backends/vulkan/_passes/replace_qdq.py +++ b/backends/vulkan/_passes/replace_qdq.py @@ -32,24 +32,23 @@ def call(self, graph_module: torch.fx.GraphModule): exir_ops.edge.et_vk.conv2d_q8ta_q8csw_q8to_dw.default, exir_ops.edge.et_vk.add_q8ta_q8ta_q8to.default, ]: - # Replace quantize op feeding into conv2d (first argument is the quantized input) - quantized_input_node = node.args[0] - if isinstance( - quantized_input_node, torch.fx.Node - ) and utils.is_quant_node(quantized_input_node): - # Get the arguments from the original quantize node - input_tensor = quantized_input_node.args[0] - scale = quantized_input_node.args[1] - zero_point = quantized_input_node.args[2] + for quantized_input_node in node.args: + if isinstance( + quantized_input_node, torch.fx.Node + ) and utils.is_quant_node(quantized_input_node): + # Get the arguments from the original quantize node + input_tensor = quantized_input_node.args[0] + scale = quantized_input_node.args[1] + zero_point = quantized_input_node.args[2] - nodes_to_replace.append( - { - "old_node": quantized_input_node, - "new_target": exir_ops.edge.et_vk.quantize_q8ta_for_conv2d.default, - "args": (input_tensor, scale, zero_point), - "node_type": "quantize_input", - } - ) + nodes_to_replace.append( + { + "old_node": quantized_input_node, + "new_target": exir_ops.edge.et_vk.quantize_q8ta_for_conv2d.default, + "args": (input_tensor, scale, zero_point), + "node_type": "quantize_input", + } + ) # Find dequantize ops that consume the output of this conv2d for user in node.users: diff --git a/backends/vulkan/_passes/tag_memory_meta_pass.py b/backends/vulkan/_passes/tag_memory_meta_pass.py index 15449b98f6f..43796c043c8 100644 --- a/backends/vulkan/_passes/tag_memory_meta_pass.py +++ b/backends/vulkan/_passes/tag_memory_meta_pass.py @@ -368,7 +368,7 @@ def constrain_op_arg_repset(self, arg_i: int, 
op_repsets: utils.OpRepSets) -> No arg_repset = op_repsets.get_arg_repset(arg_i) if arg_repset.is_constrained(): - return arg_repset + return arg_node = op_repsets.op_node.args[arg_i] @@ -378,6 +378,20 @@ def constrain_op_arg_repset(self, arg_i: int, op_repsets: utils.OpRepSets) -> No arg_repset = self.trace_node_users_to_constrain_repset(arg_node, arg_repset) op_repsets.try_constrain_with_arg_repset(arg_i, arg_repset) + def constrain_op_out_repset(self, op_repsets: utils.OpRepSets) -> None: + """ + Similar to the `constrain_op_arg_repset` function, but for the output repset of + the operator. + """ + out_repset = op_repsets.get_out_repset(0) + if out_repset.is_constrained(): + return + + op_node = op_repsets.op_node + out_repset = self.trace_node_users_to_constrain_repset(op_node, out_repset) + + op_repsets.try_constrain_with_out_repset(out_repset) + def constrain_op_repsets(self, op_repsets: utils.OpRepSets) -> None: # For most ops, constraining the argument repsets will also contrain the output # repset due to OpRepSets maintaining synchronization rules. @@ -385,14 +399,12 @@ def constrain_op_repsets(self, op_repsets: utils.OpRepSets) -> None: if utils.is_tensor_arg_node(op_repsets.op_node.args[i]): self.constrain_op_arg_repset(i, op_repsets) - # TODO(ssjia): For most ops, inputs and outputs must be synchronized, so there - # is no need to constrain output repsets explicitly. Currently, the exceptions - # (i.e. choose qparams) already define constrined repsets for the output, so - # there is again no need to explicitly constrain the outputs. If an operator - # appears later on that does not sync input and output representations, and - # defines ambiguous repsets for the output tensor(s), then we will need to add - # additional logic to this function to constrain the output repsets separately - # from the input repsets. + # However, some operators do not sync input and output representations and also + # define ambiguous repsets for the output tensor(s). 
In those cases we will need + # to execute additional logic to constrain the output repsets separately from + # the input repsets. + if not op_repsets.sync_primary_io_repr and op_repsets.sync_outs_repr: + self.constrain_op_out_repset(op_repsets) def set_op_node_tensor_reprs( self, graph_module: torch.fx.GraphModule, op_node: torch.fx.Node diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index ef41060272c..403e747141f 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -636,7 +636,7 @@ def register_quantized_binary_op(): def register_quantize_for_conv2d_op(): return OpFeatures( inputs_storage=[ - utils.CHANNELS_PACKED_TEXTURE, + utils.CHANNELS_PACKED_TEXTURE_OR_CONTIGUOUS_BUFFER, ], outputs_storage=[ utils.PACKED_INT8_4W4C_BUFFER, @@ -656,7 +656,7 @@ def register_dequantize_for_conv2d_op(): utils.PACKED_INT8_4W4C_BUFFER, ], outputs_storage=[ - utils.CHANNELS_PACKED_TEXTURE, + utils.CHANNELS_PACKED_TEXTURE_OR_CONTIGUOUS_BUFFER, ], supports_resize=False, ) diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_fp_input_tile_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/conv2d_fp_input_tile_load.glslh index be8a76421a5..a3934422e27 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_fp_input_tile_load.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_fp_input_tile_load.glslh @@ -14,7 +14,21 @@ #include "linear_fp_input_tile.glslh" VEC4_T load_fp_input_texel(const Conv2dTensorIndex tidx) { +#ifdef INPUT_BUFFER + VEC4_T texel = VEC4_T(0); + const int c_idx = mul_4(tidx.data.z); + const int c_stride = input_sizes.y * input_sizes.x; + + const int base_buf_i = c_idx * c_stride + tidx.data.y * input_sizes.x + tidx.data.x; + const int limit = min(input_sizes.z - c_idx, 4); + + for (int i = 0; i < limit; i++) { + texel[i] = t_fp_input[base_buf_i + i * c_stride]; + } + return texel; +#else return texelFetch(t_fp_input, tidx.data, 0); +#endif } void load_fp_input_tile( @@ -23,7 +37,9 @@ void 
load_fp_input_tile( #if TILE_M == 4 && TILE_K4 == 1 Conv2dTensorIndex load_tidx = block_idx_to_tensor_idx(block_idx); [[unroll]] for (int w = 0; w < TILE_M; w++) { - tile.data[w][0] = load_fp_input_texel(load_tidx); + if (load_tidx.data.x < input_sizes.x) { + tile.data[w][0] = load_fp_input_texel(load_tidx); + } load_tidx.data.x++; } #else diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_q8ta_conv2d_input.glsl b/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_q8ta_conv2d_input.glsl index d485523709b..dfa0b5a95bf 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_q8ta_conv2d_input.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_q8ta_conv2d_input.glsl @@ -31,7 +31,7 @@ layout(std430) buffer; #include "conv2d_common.glslh" ${layout_declare_tensor(B, "w", "t_packed_int8_input", "int", OUTPUT_STORAGE, is_scalar_array=False)} -${layout_declare_tensor(B, "r", "t_fp_input", DTYPE, INPUT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_fp_input", DTYPE, INPUT_STORAGE)} ${layout_declare_ubo(B, "ivec4", "input_sizes")} diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_q8ta_conv2d_input.yaml b/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_q8ta_conv2d_input.yaml index 712d3156e2e..929567a2595 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_q8ta_conv2d_input.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_q8ta_conv2d_input.yaml @@ -15,6 +15,7 @@ quantize_and_pack_q8ta_conv2d_input: combos: - parameter_values: [texture3d, texture3d] - parameter_values: [buffer, texture3d] + - parameter_values: [buffer, buffer] DTYPE: - VALUE: float shader_variants: diff --git a/backends/vulkan/runtime/graph/ops/glsl/unpack_and_dequantize_q8ta_conv2d_output.glsl b/backends/vulkan/runtime/graph/ops/glsl/unpack_and_dequantize_q8ta_conv2d_output.glsl index 798366b523a..be0a39bac3c 100644 --- 
a/backends/vulkan/runtime/graph/ops/glsl/unpack_and_dequantize_q8ta_conv2d_output.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/unpack_and_dequantize_q8ta_conv2d_output.glsl @@ -30,7 +30,7 @@ layout(std430) buffer; #include "conv2d_common.glslh" -${layout_declare_tensor(B, "w", "t_fp_output", DTYPE, OUTPUT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "w", "t_fp_output", DTYPE, OUTPUT_STORAGE)} ${layout_declare_tensor(B, "r", "t_packed_int8_output", "int", INPUT_STORAGE, is_scalar_array=False)} ${layout_declare_ubo(B, "ivec4", "output_sizes")} @@ -84,7 +84,19 @@ void unpack_and_dequantize( void store_fp_output_texel( const Conv2dTensorIndex tidx, const VEC4_T out_texel) { +#ifdef OUTPUT_BUFFER + const int c_idx = mul_4(tidx.data.z); + const int c_stride = output_sizes.y * output_sizes.x; + + const int base_buf_i = c_idx * c_stride + tidx.data.y * output_sizes.x + tidx.data.x; + const int limit = min(output_sizes.z - c_idx, 4); + + for (int i = 0; i < limit; ++i) { + t_fp_output[base_buf_i + i * c_stride] = out_texel[i]; + } +#else imageStore(t_fp_output, tidx.data, out_texel); +#endif } void store_fp_tile( @@ -92,7 +104,9 @@ void store_fp_tile( const Conv2dBlockIndex block_idx) { Conv2dTensorIndex store_tidx = block_idx_to_tensor_idx(block_idx); [[unroll]] for (int w = 0; w < 4; w++) { - store_fp_output_texel(store_tidx, block.data[w][0]); + if (store_tidx.data.x < output_sizes.x) { + store_fp_output_texel(store_tidx, block.data[w][0]); + } store_tidx.data.x++; } } diff --git a/backends/vulkan/runtime/graph/ops/glsl/unpack_and_dequantize_q8ta_conv2d_output.yaml b/backends/vulkan/runtime/graph/ops/glsl/unpack_and_dequantize_q8ta_conv2d_output.yaml index 24b253da343..ff1da144027 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/unpack_and_dequantize_q8ta_conv2d_output.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/unpack_and_dequantize_q8ta_conv2d_output.yaml @@ -15,6 +15,7 @@ unpack_and_dequantize_q8ta_conv2d_output: combos: - 
parameter_values: [texture3d, texture3d] - parameter_values: [texture3d, buffer] + - parameter_values: [buffer, buffer] DTYPE: - VALUE: float shader_variants: diff --git a/backends/vulkan/test/custom_ops/q8ta_q8csw_q8to_conv2d.cpp b/backends/vulkan/test/custom_ops/q8ta_q8csw_q8to_conv2d.cpp index 8762fe4c0d1..bbd4af7579c 100644 --- a/backends/vulkan/test/custom_ops/q8ta_q8csw_q8to_conv2d.cpp +++ b/backends/vulkan/test/custom_ops/q8ta_q8csw_q8to_conv2d.cpp @@ -47,11 +47,15 @@ TestCase create_test_case_from_config( std::vector input_size = { 1, config.channels.in, config.input_size.h, config.input_size.w}; + utils::GPUMemoryLayout io_memory_layout = storage_type == utils::kBuffer + ? utils::kWidthPacked + : utils::kChannelsPacked; + ValueSpec input_tensor( input_size, input_dtype, storage_type, - utils::kChannelsPacked, + io_memory_layout, DataGenType::RANDOM); if (debugging()) { @@ -139,7 +143,7 @@ TestCase create_test_case_from_config( {1, config.channels.out, H_out, W_out}, input_dtype, storage_type, - utils::kChannelsPacked, + io_memory_layout, DataGenType::ZEROS); // Add all specs to test case for q8ta_q8csw_q8to operation @@ -182,7 +186,8 @@ std::vector generate_quantized_conv2d_easy_cases() { config.op_name = "conv2d_q8ta_q8csw_q8to"; // Test with both storage types and data types for completeness - std::vector storage_types = {utils::kTexture3D}; + std::vector storage_types = { + utils::kTexture3D, utils::kBuffer}; std::vector float_types = {vkapi::kFloat}; // Generate test cases for each combination @@ -341,7 +346,8 @@ std::vector generate_quantized_conv2d_test_cases() { 4}}; // Test with different storage types and data types - std::vector storage_types = {utils::kTexture3D}; + std::vector storage_types = { + utils::kTexture3D, utils::kBuffer}; // Generate test cases for each combination for (auto& config : configs) { @@ -621,7 +627,7 @@ int main(int argc, char* argv[]) { quantized_conv2d_flop_calculator, "QuantizedConv2dQ8ToQ8To", 0, - 10, + 1, ref_fn); 
return 0; diff --git a/backends/vulkan/test/custom_ops/q8ta_q8ta_q8to_add.cpp b/backends/vulkan/test/custom_ops/q8ta_q8ta_q8to_add.cpp index 5799bc194c9..eb8e6908060 100644 --- a/backends/vulkan/test/custom_ops/q8ta_q8ta_q8to_add.cpp +++ b/backends/vulkan/test/custom_ops/q8ta_q8ta_q8to_add.cpp @@ -38,21 +38,17 @@ TestCase create_quantized_add_test_case( // Set the operator name for the test case test_case.set_operator_name("et_vk.add_q8ta_q8ta_q8to.test"); + utils::GPUMemoryLayout io_memory_layout = storage_type == utils::kBuffer + ? utils::kWidthPacked + : utils::kChannelsPacked; + // Input tensor A (float/half) ValueSpec input_a( - sizes, - input_dtype, - storage_type, - utils::kChannelsPacked, - DataGenType::RANDOM); + sizes, input_dtype, storage_type, io_memory_layout, DataGenType::RANDOM); // Input tensor B (float/half) ValueSpec input_b( - sizes, - input_dtype, - storage_type, - utils::kChannelsPacked, - DataGenType::RANDOM); + sizes, input_dtype, storage_type, io_memory_layout, DataGenType::RANDOM); // Quantization parameters for input A float input_a_scale_val = 0.007843; // 2/255 approximately @@ -81,11 +77,7 @@ TestCase create_quantized_add_test_case( // Output tensor (float/half) ValueSpec output( - sizes, - input_dtype, - storage_type, - utils::kChannelsPacked, - DataGenType::ZEROS); + sizes, input_dtype, storage_type, io_memory_layout, DataGenType::ZEROS); // Add all specs to test case for q8ta_q8ta_q8to add operation test_case.add_input_spec(input_a); @@ -119,7 +111,8 @@ std::vector generate_quantized_add_test_cases() { }; // Storage types to test - std::vector storage_types = {utils::kTexture3D}; + std::vector storage_types = { + utils::kTexture3D, utils::kBuffer}; // Data types to test std::vector data_types = {vkapi::kFloat}; diff --git a/backends/vulkan/utils.py b/backends/vulkan/utils.py index 9c527cbc36a..fca8173ffb7 100644 --- a/backends/vulkan/utils.py +++ b/backends/vulkan/utils.py @@ -772,6 +772,14 @@ def make_filtered_tensor_repset( 
HEIGHT_PACKED_TEXTURE = TensorRepSet(set(), {VkMemoryLayout.TENSOR_HEIGHT_PACKED}) CHANNELS_PACKED_TEXTURE = TensorRepSet(set(), {VkMemoryLayout.TENSOR_CHANNELS_PACKED}) +CHANNELS_PACKED_ANY = TensorRepSet( + {VkMemoryLayout.TENSOR_CHANNELS_PACKED}, {VkMemoryLayout.TENSOR_CHANNELS_PACKED} +) + +CHANNELS_PACKED_TEXTURE_OR_CONTIGUOUS_BUFFER = TensorRepSet( + {VkMemoryLayout.TENSOR_WIDTH_PACKED}, {VkMemoryLayout.TENSOR_CHANNELS_PACKED} +) + ANY_TEXTURE = TensorRepSet(set(), all_memory_layouts) ANY_BUFFER = TensorRepSet(all_memory_layouts, set()) @@ -1098,6 +1106,25 @@ def try_constrain_with_arg_repset( self.assert_sync_contraints() return True + def try_constrain_with_out_repset(self, repset: TensorRepSet): + # Skip for operators that must synchronize the input and output representations + # or operators that have more than one output repset + if self.sync_primary_io_repr or len(self.outs_repset_list) > 1: + return False + + out_current_repset = self.outs_repset_list[0] + + if out_current_repset == repset: + return False + + if not out_current_repset.any_in_common(repset): + return False + + self.outs_repset_list[0] = out_current_repset.make_intersect(repset) + + self.assert_sync_contraints() + return True + def pick_representations(self) -> Tuple[TensorReprList, TensorReprList]: """ For each tensor participating in the op, pick a representation for it among the