diff --git a/backends/vulkan/runtime/graph/ops/glsl/bitw8_image_to_nchw_nobitw8buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/bitw8_image_to_nchw_nobitw8buffer.yaml deleted file mode 100644 index e15e27addad..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/bitw8_image_to_nchw_nobitw8buffer.yaml +++ /dev/null @@ -1,19 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -bitw8_image_to_nchw_nobitw8buffer: - parameter_names_with_default_values: - STORAGE: texture3d - DTYPE: int8 - generate_variant_forall: - DTYPE: - - VALUE: int8 - - VALUE: uint8 - STORAGE: - - VALUE: texture2d - - VALUE: texture3d - shader_variants: - - NAME: bitw8_image_to_nchw_nobitw8buffer diff --git a/backends/vulkan/runtime/graph/ops/glsl/bitw8_image_to_nchw_nobitw8buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/int8_image_to_nchw_noint8.glsl similarity index 91% rename from backends/vulkan/runtime/graph/ops/glsl/bitw8_image_to_nchw_nobitw8buffer.glsl rename to backends/vulkan/runtime/graph/ops/glsl/int8_image_to_nchw_noint8.glsl index 4fd6e2f14aa..f7133dd0452 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/bitw8_image_to_nchw_nobitw8buffer.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/int8_image_to_nchw_noint8.glsl @@ -10,8 +10,6 @@ #define PRECISION ${PRECISION} -${define_active_storage_type(STORAGE)} - #include "indexing_utils.h" layout(std430) buffer; @@ -19,7 +17,7 @@ layout(std430) buffer; #extension GL_EXT_control_flow_attributes : require ${layout_declare_buffer(B, "w", "nchw_out", "int")} -${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)} +${layout_declare_tensor(B, "r", "t_in", "int8", "texture3d")} ${layout_declare_ubo(B, "ivec4", "tensor_sizes")} ${layout_declare_ubo(B, "ivec4", "axis_map")} ${layout_declare_ubo(B, "int", "out_numel")} @@ -46,7 +44,7 @@ void main() { const ivec4 tidx = nchwi_to_tidx(in_buf_idx, tensor_sizes); const ivec4 texture_pos = to_texture_elem_pos( tidx, tensor_sizes, packed_dim); - values[i] = ivec4(load_texel(t_in, texture_pos.xyz))[texture_pos.w]; + values[i] = load_texel(t_in, texture_pos.xyz)[texture_pos.w]; in_buf_idx++; } diff --git a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_bitw8_image_nobitw8buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_bitw8_image_nobitw8buffer.yaml deleted file mode 100644 index 7fe3849fd5c..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_bitw8_image_nobitw8buffer.yaml +++ /dev/null @@ -1,19 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -nchw_to_bitw8_image_nobitw8buffer: - parameter_names_with_default_values: - STORAGE: texture3d - DTYPE: int8 - generate_variant_forall: - DTYPE: - - VALUE: int8 - - VALUE: uint8 - STORAGE: - - VALUE: texture2d - - VALUE: texture3d - shader_variants: - - NAME: nchw_to_bitw8_image_nobitw8buffer diff --git a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_bitw8_image_nobitw8buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_int8_image_noint8.glsl similarity index 89% rename from backends/vulkan/runtime/graph/ops/glsl/nchw_to_bitw8_image_nobitw8buffer.glsl rename to backends/vulkan/runtime/graph/ops/glsl/nchw_to_int8_image_noint8.glsl index 8a3ef68528f..f3a3370f3ba 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_bitw8_image_nobitw8buffer.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_int8_image_noint8.glsl @@ -10,17 +10,13 @@ #define PRECISION ${PRECISION} -#define VEC4_T ${texel_load_type(DTYPE, STORAGE)} - -${define_active_storage_type(STORAGE)} - #include "indexing_utils.h" layout(std430) buffer; #extension GL_EXT_control_flow_attributes : require -${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)} +${layout_declare_tensor(B, "w", "t_out", "int8", "texture3d")} ${layout_declare_buffer(B, "r", "nchw_in", "int")} ${layout_declare_ubo(B, "ivec4", "sizes")} ${layout_declare_ubo(B, "ivec4", "axis_map")} @@ -75,5 +71,5 @@ void main() { return; } - write_texel(t_out, lpos_to_pos(lpos, axis_map), VEC4_T(read_texel(tidx))); + write_texel(t_out, lpos_to_pos(lpos, axis_map), read_texel(tidx)); } diff --git a/backends/vulkan/runtime/graph/ops/impl/Staging.cpp b/backends/vulkan/runtime/graph/ops/impl/Staging.cpp index f3966387042..ef6e8347df8 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Staging.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Staging.cpp @@ -80,7 +80,7 @@ void add_tensor_to_staging_node( // output buffer. Therefore, the global work group size for this shader will // be the number of elements in the output buffer divided by 4, as opposed to // the extents of the input texture. - if (shader.kernel_name.starts_with("bitw8_image_to_nchw_nobitw8buffer")) { + if (shader.kernel_name == "int8_image_to_nchw_noint8") { uint32_t buffer_len = graph.get_staging(out_staging)->numel() / 4; global_wg_size = {buffer_len, 1, 1}; ubos.append({graph.numel_ubo(in_tensor)}); diff --git a/backends/vulkan/runtime/graph/ops/utils/StagingUtils.cpp b/backends/vulkan/runtime/graph/ops/utils/StagingUtils.cpp index 934fd03ab7f..8804bcf2ef6 100644 --- a/backends/vulkan/runtime/graph/ops/utils/StagingUtils.cpp +++ b/backends/vulkan/runtime/graph/ops/utils/StagingUtils.cpp @@ -15,23 +15,15 @@ namespace vkcompute { -bool is_bitw8(vkapi::ScalarType dtype) { - return dtype == vkapi::kByte || dtype == vkapi::kChar || - dtype == vkapi::kQInt8 || dtype == vkapi::kQUInt8; -} - vkapi::ShaderInfo get_nchw_to_tensor_shader( const api::vTensor& v_dst, const bool int8_buffer_enabled) { std::string kernel_name; kernel_name.reserve(kShaderNameReserve); - if (is_bitw8(v_dst.dtype()) && v_dst.storage_type() != utils::kBuffer && - !int8_buffer_enabled) { - kernel_name = "nchw_to_bitw8_image_nobitw8buffer"; - add_dtype_suffix(kernel_name, v_dst); - add_storage_type_suffix(kernel_name, v_dst); - return VK_KERNEL_FROM_STR(kernel_name); + if (v_dst.dtype() == vkapi::kChar && + v_dst.storage_type() == utils::kTexture3D && !int8_buffer_enabled) { + return VK_KERNEL(nchw_to_int8_image_noint8); } if (v_dst.storage_type() == utils::kBuffer) { @@ -53,12 +45,9 @@ vkapi::ShaderInfo get_tensor_to_nchw_shader( std::string kernel_name; kernel_name.reserve(kShaderNameReserve); - if (is_bitw8(v_src.dtype()) && v_src.storage_type() != utils::kBuffer && - !int8_buffer_enabled) { - kernel_name = "bitw8_image_to_nchw_nobitw8buffer"; - add_dtype_suffix(kernel_name, v_src); - add_storage_type_suffix(kernel_name, v_src); - return VK_KERNEL_FROM_STR(kernel_name); + if (v_src.dtype() == vkapi::kChar && + v_src.storage_type() == utils::kTexture3D && !int8_buffer_enabled) { + return VK_KERNEL(int8_image_to_nchw_noint8); } if (v_src.storage_type() == utils::kBuffer) { diff --git a/backends/vulkan/test/utils/test_utils.cpp b/backends/vulkan/test/utils/test_utils.cpp index 73e2f049a33..e4ada921226 100644 --- a/backends/vulkan/test/utils/test_utils.cpp +++ b/backends/vulkan/test/utils/test_utils.cpp @@ -111,20 +111,15 @@ void record_image_to_nchw_op( v_src.axis_map_ubo()); } -void record_bitw8_image_to_nchw_nobitw8buffer_op( +void record_int8_image_to_nchw_noint8_op( api::Context* const context, api::vTensor& v_src, api::StagingBuffer& dst_buffer) { vkapi::PipelineBarrier pipeline_barrier{}; uint32_t buffer_len = utils::safe_downcast(dst_buffer.numel() / 4); utils::uvec3 global_wg_size = {buffer_len, 1, 1}; - - std::string kernel_name = "bitw8_image_to_nchw_nobitw8buffer"; - add_dtype_suffix(kernel_name, v_src); - add_storage_type_suffix(kernel_name, v_src); - context->submit_compute_job( - VK_KERNEL_FROM_STR(kernel_name), + VK_KERNEL(int8_image_to_nchw_noint8), pipeline_barrier, global_wg_size, adaptive_work_group_size(global_wg_size), diff --git a/backends/vulkan/test/utils/test_utils.h b/backends/vulkan/test/utils/test_utils.h index f3ee2a717a5..d9d83a9620f 100644 --- a/backends/vulkan/test/utils/test_utils.h +++ b/backends/vulkan/test/utils/test_utils.h @@ -84,7 +84,7 @@ void record_image_to_nchw_op( vkcompute::api::vTensor& v_src, vkcompute::vkapi::VulkanBuffer& dst_buffer); -void record_bitw8_image_to_nchw_nobitw8buffer_op( +void record_int8_image_to_nchw_noint8_op( vkcompute::api::Context* const context, vkcompute::api::vTensor& v_src, vkcompute::api::StagingBuffer& dst_buffer); diff --git a/backends/vulkan/test/vulkan_compute_api_test.cpp b/backends/vulkan/test/vulkan_compute_api_test.cpp index f3ee1183612..c0840d2864a 100644 --- a/backends/vulkan/test/vulkan_compute_api_test.cpp +++ b/backends/vulkan/test/vulkan_compute_api_test.cpp @@ -2365,8 +2365,7 @@ void run_from_gpu_test( if (dtype == vkapi::kChar && !context()->adapter_ptr()->has_full_int8_buffers_support()) { - record_bitw8_image_to_nchw_nobitw8buffer_op( - context(), vten, staging_buffer); + record_int8_image_to_nchw_noint8_op(context(), vten, staging_buffer); } else { record_image_to_nchw_op(context(), vten, staging_buffer.buffer()); } @@ -2413,8 +2412,7 @@ void round_trip_test( // Copy data in and out of the tensor if (dtype == vkapi::kChar && !context()->adapter_ptr()->has_full_int8_buffers_support()) { - record_bitw8_image_to_nchw_nobitw8buffer_op( - context(), vten, staging_buffer_out); + record_int8_image_to_nchw_noint8_op(context(), vten, staging_buffer_out); } else { record_image_to_nchw_op(context(), vten, staging_buffer_out.buffer()); }