diff --git a/backends/vulkan/runtime/api/Context.cpp b/backends/vulkan/runtime/api/Context.cpp
index 9517941f364..f425859935d 100644
--- a/backends/vulkan/runtime/api/Context.cpp
+++ b/backends/vulkan/runtime/api/Context.cpp
@@ -87,6 +87,27 @@ void Context::report_shader_dispatch_end() {
   }
 }
 
+void Context::check_device_capabilities(const vkapi::ShaderInfo& shader) {
+  if (shader.requires_shader_int16) {
+    if (!adapter_p_->supports_int16_shader_types()) {
+      throw vkapi::ShaderNotSupportedError(
+          shader.kernel_name, vkapi::VulkanExtension::SHADER_INT16);
+    }
+  }
+  if (shader.requires_16bit_storage) {
+    if (!adapter_p_->supports_16bit_storage_buffers()) {
+      throw vkapi::ShaderNotSupportedError(
+          shader.kernel_name, vkapi::VulkanExtension::INT16_STORAGE);
+    }
+  }
+  if (shader.requires_8bit_storage) {
+    if (!adapter_p_->supports_8bit_storage_buffers()) {
+      throw vkapi::ShaderNotSupportedError(
+          shader.kernel_name, vkapi::VulkanExtension::INT8_STORAGE);
+    }
+  }
+}
+
 vkapi::DescriptorSet Context::get_descriptor_set(
     const vkapi::ShaderInfo& shader_descriptor,
     const utils::uvec3& local_workgroup_size,
diff --git a/backends/vulkan/runtime/api/Context.h b/backends/vulkan/runtime/api/Context.h
index 300fd3995dd..0c199c24cc4 100644
--- a/backends/vulkan/runtime/api/Context.h
+++ b/backends/vulkan/runtime/api/Context.h
@@ -185,6 +185,14 @@ class Context final {
     }
   }
 
+  /*
+   * Checks that the adapter supports every optional capability the shader is
+   * flagged as requiring (shaderInt16, 16-bit storage buffers, 8-bit storage
+   * buffers, in that order). Throws vkapi::ShaderNotSupportedError for the
+   * first capability that is missing.
+   */
+  void check_device_capabilities(const vkapi::ShaderInfo& shader);
+
   vkapi::DescriptorSet get_descriptor_set(
       const vkapi::ShaderInfo&,
       const utils::uvec3&,
diff --git a/backends/vulkan/runtime/gen_vulkan_spv.py b/backends/vulkan/runtime/gen_vulkan_spv.py
index 7d004547a8e..7d3d2d52950 100644
--- a/backends/vulkan/runtime/gen_vulkan_spv.py
+++ b/backends/vulkan/runtime/gen_vulkan_spv.py
@@ -720,6 +720,14 @@ def maybe_replace_u16vecn(self, input_text: str) -> str:
         if "codegen-nosub" in input_text:
             return input_text
 
+        # Remove the int16 extension requirement so that the generated ShaderInfo
+        # does not mark it. Match with a regex (same spacing tolerance as
+        # isExtensionRequireLine) so "...int16: require" is stripped as well.
+        input_text = re.sub(
+            r"#extension GL_EXT_shader_explicit_arithmetic_types_int16\s*:\s*require",
+            "",
+            input_text,
+        )
         input_text = input_text.replace("u16vec", "ivec")
         input_text = input_text.replace("uint16_t", "int")
         return input_text
@@ -791,6 +799,11 @@ class ShaderInfo:
     weight_storage_type: str = ""
     bias_storage_type: str = ""
     register_for: Optional[Tuple[str, List[str]]] = None
+    # Set by getShaderInfo() when the shader source contains the matching
+    # "#extension <name> : require" directive.
+    requires_shader_int16_ext: bool = False
+    requires_16bit_storage_ext: bool = False
+    requires_8bit_storage_ext: bool = False
 
 
 def getName(filePath: str) -> str:
@@ -858,6 +871,12 @@ def findRegisterFor(lineStr: str) -> Tuple[str, List[str]]:
     return (matches_list[0], matches_list[1:])
 
 
+def isExtensionRequireLine(lineStr: str) -> bool:
+    """Return True if lineStr starts with a `#extension <name> : require` directive."""
+    extension_require_id = r"^#extension ([A-Za-z0-9_]+)\s*:\s*require"
+    return re.search(extension_require_id, lineStr) is not None
+
+
 typeIdMapping = {
     r"image[123]D\b": "VK_DESCRIPTOR_TYPE_STORAGE_IMAGE",
     r"sampler[123]D\b": "VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER",
@@ -889,6 +908,13 @@ def getShaderInfo(srcFilePath: str) -> ShaderInfo:
                 shader_info.bias_storage_type = getBiasStorageType(line)
             if isRegisterForLine(line):
                 shader_info.register_for = findRegisterFor(line)
+            if isExtensionRequireLine(line):
+                if "GL_EXT_shader_explicit_arithmetic_types_int16" in line:
+                    shader_info.requires_shader_int16_ext = True
+                if "GL_EXT_shader_16bit_storage" in line:
+                    shader_info.requires_16bit_storage_ext = True
+                if "GL_EXT_shader_8bit_storage" in line:
+                    shader_info.requires_8bit_storage_ext = True
 
     return shader_info
 
@@ -952,12 +978,18 @@ def generateShaderInfoStr(shader_info: ShaderInfo, name: str, sizeBytes: int) ->
 
     shader_info_layouts = "{{{}}}".format(",\n ".join(shader_info.layouts))
 
+    def to_cpp_str(val: bool) -> str:
+        return "true" if val else "false"
+
     shader_info_args = [
         f'"{name}"',
         f"{name}_bin",
         str(sizeBytes),
         shader_info_layouts,
         tile_size,
+        to_cpp_str(shader_info.requires_shader_int16_ext),
+        to_cpp_str(shader_info.requires_16bit_storage_ext),
+        to_cpp_str(shader_info.requires_8bit_storage_ext),
     ]
 
     shader_info_str = textwrap.indent(
diff --git a/backends/vulkan/runtime/graph/ops/DispatchNode.cpp b/backends/vulkan/runtime/graph/ops/DispatchNode.cpp
index a163a0d7aea..63b8798f2c1 100644
--- a/backends/vulkan/runtime/graph/ops/DispatchNode.cpp
+++ b/backends/vulkan/runtime/graph/ops/DispatchNode.cpp
@@ -58,6 +58,10 @@ void DispatchNode::encode(ComputeGraph* graph) {
   api::Context* const context = graph->context();
   vkapi::PipelineBarrier pipeline_barrier{};
 
+  // Reject shaders the device cannot run before taking the dispatch lock;
+  // throws vkapi::ShaderNotSupportedError.
+  context->check_device_capabilities(shader_);
+
   std::unique_lock cmd_lock = context->dispatch_lock();
 
   std::array push_constants_data;
diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl
index 20fb9374bec..4a8d7418691 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl
@@ -34,8 +34,6 @@ ${layout_declare_ubo(8, "float", "out_min", "float", "out_max")}
 
 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
 
-#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
-
 /*
  * Computes a depthwise convolution. Each shader invocation calculates the
  * output at a single output location.
diff --git a/backends/vulkan/runtime/vk_api/Adapter.cpp b/backends/vulkan/runtime/vk_api/Adapter.cpp
index 5805d476a38..ec30650ba06 100644
--- a/backends/vulkan/runtime/vk_api/Adapter.cpp
+++ b/backends/vulkan/runtime/vk_api/Adapter.cpp
@@ -256,6 +256,9 @@ std::string Adapter::stringize() const {
   ss << " deviceType: " << device_type << std::endl;
   ss << " deviceName: " << properties.deviceName << std::endl;
 
+#define PRINT_BOOL(value, name) \
+  ss << " " << std::left << std::setw(36) << #name << value << std::endl;
+
 #define PRINT_PROP(struct, name) \
   ss << " " << std::left << std::setw(36) << #name << struct.name \
      << std::endl;
@@ -298,12 +301,13 @@ std::string Adapter::stringize() const {
   ss << " }" << std::endl;
 #endif /* VK_KHR_8bit_storage */
 
-#ifdef VK_KHR_shader_float16_int8
   ss << " Shader 16bit and 8bit Features {" << std::endl;
+  PRINT_BOOL(physical_device_.supports_int16_shader_types, shaderInt16);
+#ifdef VK_KHR_shader_float16_int8
   PRINT_PROP(physical_device_.shader_float16_int8_types, shaderFloat16);
   PRINT_PROP(physical_device_.shader_float16_int8_types, shaderInt8);
-  ss << " }" << std::endl;
 #endif /* VK_KHR_shader_float16_int8 */
+  ss << " }" << std::endl;
 
   const VkPhysicalDeviceMemoryProperties& mem_props =
       physical_device_.memory_properties;
diff --git a/backends/vulkan/runtime/vk_api/Exception.cpp b/backends/vulkan/runtime/vk_api/Exception.cpp
index e330c1c079d..d26fbd8cb22 100644
--- a/backends/vulkan/runtime/vk_api/Exception.cpp
+++ b/backends/vulkan/runtime/vk_api/Exception.cpp
@@ -77,5 +77,37 @@ Error::Error(SourceLocation source_location, const char* cond, std::string msg)
   what_ = oss.str();
 }
 
+//
+// ShaderNotSupportedError
+//
+
+// Streams the name of the extension / physical device feature that `result`
+// stands for; used to build ShaderNotSupportedError's message.
+std::ostream& operator<<(std::ostream& out, const VulkanExtension result) {
+  switch (result) {
+    case VulkanExtension::SHADER_INT16:
+      return out << "shaderInt16";
+    case VulkanExtension::INT16_STORAGE:
+      return out << "VK_KHR_16bit_storage";
+    case VulkanExtension::INT8_STORAGE:
+      return out << "VK_KHR_8bit_storage";
+  }
+  // Only reached for a value outside the enumerators above; emit a marker
+  // rather than leaving the error message truncated.
+  return out << "<unknown>";
+}
+
+ShaderNotSupportedError::ShaderNotSupportedError(
+    std::string shader_name,
+    VulkanExtension extension)
+    : shader_name_(std::move(shader_name)), extension_{extension} {
+  std::ostringstream oss;
+  oss << "Shader " << shader_name_ << " ";
+  oss << "not compatible with device. ";
+  oss << "Missing support for extension or physical device feature: ";
+  oss << extension_;
+  what_ = oss.str();
+}
+
 } // namespace vkapi
 } // namespace vkcompute
diff --git a/backends/vulkan/runtime/vk_api/Exception.h b/backends/vulkan/runtime/vk_api/Exception.h
index ec2f2956a88..a65afb1bcc5 100644
--- a/backends/vulkan/runtime/vk_api/Exception.h
+++ b/backends/vulkan/runtime/vk_api/Exception.h
@@ -78,5 +78,30 @@ class Error : public std::exception {
   }
 };
 
+// Optional device capabilities that a shader can require. Despite the name,
+// SHADER_INT16 is reported as the physical device feature shaderInt16.
+enum class VulkanExtension : uint8_t {
+  SHADER_INT16,
+  INT16_STORAGE,
+  INT8_STORAGE,
+};
+
+// Thrown when a shader requires a capability the device lacks. what() names
+// the shader and the missing extension or physical device feature.
+class ShaderNotSupportedError : public std::exception {
+ public:
+  ShaderNotSupportedError(std::string shader_name, VulkanExtension extension);
+
+ private:
+  std::string shader_name_;
+  VulkanExtension extension_;
+  std::string what_;
+
+ public:
+  const char* what() const noexcept override {
+    return what_.c_str();
+  }
+};
+
 } // namespace vkapi
 } // namespace vkcompute
diff --git a/backends/vulkan/runtime/vk_api/Shader.cpp b/backends/vulkan/runtime/vk_api/Shader.cpp
index 29774e2f404..e560f37868e 100644
--- a/backends/vulkan/runtime/vk_api/Shader.cpp
+++ b/backends/vulkan/runtime/vk_api/Shader.cpp
@@ -28,14 +28,20 @@ ShaderInfo::ShaderInfo(
     const uint32_t* const spirv_bin,
     const uint32_t size,
     std::vector layout,
-    const utils::uvec3 tile_size)
+    const utils::uvec3 tile_size,
+    const bool requires_shader_int16_ext,
+    const bool requires_16bit_storage_ext,
+    const bool requires_8bit_storage_ext)
     : src_code{
           spirv_bin,
           size,
       },
       kernel_name{std::move(name)},
       kernel_layout{std::move(layout)},
-      out_tile_size(tile_size) {
+      out_tile_size(tile_size),
+      requires_shader_int16(requires_shader_int16_ext),
+      requires_16bit_storage(requires_16bit_storage_ext),
+      requires_8bit_storage(requires_8bit_storage_ext) {
 }
 
 bool operator==(const ShaderInfo& _1, const ShaderInfo& _2) {
diff --git a/backends/vulkan/runtime/vk_api/Shader.h b/backends/vulkan/runtime/vk_api/Shader.h
index 1e3b2a799f2..d9fec65febc 100644
--- a/backends/vulkan/runtime/vk_api/Shader.h
+++ b/backends/vulkan/runtime/vk_api/Shader.h
@@ -62,6 +62,9 @@ struct ShaderInfo final {
 
   // Shader Metadata
   utils::uvec3 out_tile_size{1u, 1u, 1u};
+  bool requires_shader_int16 = false;
+  bool requires_16bit_storage = false;
+  bool requires_8bit_storage = false;
 
   explicit ShaderInfo();
 
@@ -70,7 +73,10 @@ struct ShaderInfo final {
       const uint32_t*,
       const uint32_t,
       std::vector,
-      const utils::uvec3 tile_size);
+      const utils::uvec3 tile_size,
+      const bool requires_shader_int16_ext,
+      const bool requires_16bit_storage_ext,
+      const bool requires_8bit_storage_ext);
 
   operator bool() const {
     return src_code.bin != nullptr;
diff --git a/backends/vulkan/test/op_tests/targets.bzl b/backends/vulkan/test/op_tests/targets.bzl
index ab55d5beeaf..d26f1a805c3 100644
--- a/backends/vulkan/test/op_tests/targets.bzl
+++ b/backends/vulkan/test/op_tests/targets.bzl
@@ -3,6 +3,52 @@ load("@fbsource//xplat/caffe2:pt_defs.bzl", "get_pt_ops_deps")
 load("@fbsource//xplat/caffe2:pt_ops.bzl", "pt_operator_library")
 load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
 
+def define_test_targets(test_name, extra_deps = [], src_file = None, is_fbcode = False):
+    """Defines `<test_name>_bin` (cxx_binary) and `<test_name>` (cxx_test) from one source.
+
+    Args:
+        test_name: name of the cxx_test target; the binary is named `<test_name>_bin`.
+        extra_deps: deps appended to the gtest / vulkan_graph_runtime / libtorch defaults.
+        src_file: source to compile; defaults to `<test_name>.cpp`.
+        is_fbcode: currently unused.
+    """
+    deps_list = [
+        "//third-party/googletest:gtest_main",
+        "//executorch/backends/vulkan:vulkan_graph_runtime",
+        runtime.external_dep_location("libtorch"),
+    ] + extra_deps
+
+    src_file_str = src_file if src_file else "{}.cpp".format(test_name)
+
+    runtime.cxx_binary(
+        name = "{}_bin".format(test_name),
+        srcs = [
+            src_file_str,
+        ],
+        compiler_flags = [
+            "-Wno-unused-variable",
+        ],
+        define_static_target = False,
+        deps = deps_list,
+    )
+
+    runtime.cxx_test(
+        name = test_name,
+        srcs = [
+            src_file_str,
+        ],
+        contacts = ["oncall+ai_infra_mobile_platform@xmail.facebook.com"],
+        fbandroid_additional_loaded_sonames = [
+            "torch-code-gen",
+            "vulkan_graph_runtime",
+            "vulkan_graph_runtime_shaderlib",
+        ],
+        platforms = [ANDROID],
+        use_instrumentation_test = True,
+        deps = deps_list,
+    )
+
+
 def define_common_targets(is_fbcode = False):
     if is_fbcode:
         return
@@ -82,19 +128,6 @@ def define_common_targets(is_fbcode = False):
         default_outs = ["."],
     )
 
-    runtime.cxx_binary(
-        name = "compute_graph_op_tests_bin",
-        srcs = [
-            ":generated_op_correctness_tests_cpp[op_tests.cpp]",
-        ],
-        define_static_target = False,
-        deps = [
-            "//third-party/googletest:gtest_main",
-            "//executorch/backends/vulkan:vulkan_graph_runtime",
-            runtime.external_dep_location("libtorch"),
-        ],
-    )
-
     runtime.cxx_binary(
         name = "compute_graph_op_benchmarks_bin",
         srcs = [
@@ -111,135 +144,17 @@ def define_common_targets(is_fbcode = False):
         ],
     )
 
-    runtime.cxx_test(
-        name = "compute_graph_op_tests",
-        srcs = [
-            ":generated_op_correctness_tests_cpp[op_tests.cpp]",
-        ],
-        contacts = ["oncall+ai_infra_mobile_platform@xmail.facebook.com"],
-        fbandroid_additional_loaded_sonames = [
-            "torch-code-gen",
-            "vulkan_graph_runtime",
-            "vulkan_graph_runtime_shaderlib",
-        ],
-        platforms = [ANDROID],
-        use_instrumentation_test = True,
-        deps = [
-            "//third-party/googletest:gtest_main",
-            "//executorch/backends/vulkan:vulkan_graph_runtime",
-            runtime.external_dep_location("libtorch"),
-        ],
+    define_test_targets(
+        "compute_graph_op_tests",
+        src_file = ":generated_op_correctness_tests_cpp[op_tests.cpp]",
     )
 
-    runtime.cxx_binary(
-        name = "sdpa_test_bin",
-        srcs = [
-            "sdpa_test.cpp",
-        ],
-        compiler_flags = [
-            "-Wno-unused-variable",
-        ],
-        define_static_target = False,
-        deps = [
-            "//third-party/googletest:gtest_main",
-            "//executorch/backends/vulkan:vulkan_graph_runtime",
-            "//executorch/extension/llm/custom_ops:custom_ops_aot_lib",
-        ],
-    )
-
-    runtime.cxx_test(
-        name = "sdpa_test",
-        srcs = [
-            "sdpa_test.cpp",
-        ],
-        contacts = ["oncall+ai_infra_mobile_platform@xmail.facebook.com"],
-        fbandroid_additional_loaded_sonames = [
-            "torch-code-gen",
-            "vulkan_graph_runtime",
-            "vulkan_graph_runtime_shaderlib",
-        ],
-        platforms = [ANDROID],
-        use_instrumentation_test = True,
-        deps = [
-            "//third-party/googletest:gtest_main",
-            "//executorch/backends/vulkan:vulkan_graph_runtime",
-            "//executorch/extension/llm/custom_ops:custom_ops_aot_lib",
-            "//executorch/extension/tensor:tensor",
-            runtime.external_dep_location("libtorch"),
-        ],
-    )
-
-    runtime.cxx_binary(
-        name = "linear_weight_int4_test_bin",
-        srcs = [
-            "linear_weight_int4_test.cpp",
-        ],
-        compiler_flags = [
-            "-Wno-unused-variable",
-        ],
-        define_static_target = False,
-        deps = [
-            "//third-party/googletest:gtest_main",
-            "//executorch/backends/vulkan:vulkan_graph_runtime",
-            runtime.external_dep_location("libtorch"),
-        ],
-    )
-
-    runtime.cxx_test(
-        name = "linear_weight_int4_test",
-        srcs = [
-            "linear_weight_int4_test.cpp",
-        ],
-        contacts = ["oncall+ai_infra_mobile_platform@xmail.facebook.com"],
-        fbandroid_additional_loaded_sonames = [
-            "torch-code-gen",
-            "vulkan_graph_runtime",
-            "vulkan_graph_runtime_shaderlib",
-        ],
-        platforms = [ANDROID],
-        use_instrumentation_test = True,
-        deps = [
-            "//third-party/googletest:gtest_main",
-            "//executorch/backends/vulkan:vulkan_graph_runtime",
+    define_test_targets(
+        "sdpa_test",
+        extra_deps = [
             "//executorch/extension/llm/custom_ops:custom_ops_aot_lib",
             "//executorch/extension/tensor:tensor",
-            runtime.external_dep_location("libtorch"),
-        ],
-    )
-
-    runtime.cxx_binary(
-        name = "rotary_embedding_test_bin",
-        srcs = [
-            "rotary_embedding_test.cpp",
-        ],
-        compiler_flags = [
-            "-Wno-unused-variable",
-        ],
-        define_static_target = False,
-        deps = [
-            "//third-party/googletest:gtest_main",
-            "//executorch/backends/vulkan:vulkan_graph_runtime",
-            runtime.external_dep_location("libtorch"),
-        ],
-    )
-
-    runtime.cxx_test(
-        name = "rotary_embedding_test",
-        srcs = [
-            "rotary_embedding_test.cpp",
-        ],
-        contacts = ["oncall+ai_infra_mobile_platform@xmail.facebook.com"],
-        fbandroid_additional_loaded_sonames = [
-            "torch-code-gen",
-            "vulkan_graph_runtime",
-            "vulkan_graph_runtime_shaderlib",
-        ],
-        platforms = [ANDROID],
-        use_instrumentation_test = True,
-        deps = [
-            "//third-party/googletest:gtest_main",
-            "//executorch/backends/vulkan:vulkan_graph_runtime",
-            "//executorch/extension/tensor:tensor",
-            runtime.external_dep_location("libtorch"),
-        ],
+        ],
     )
+    define_test_targets("linear_weight_int4_test")
+    define_test_targets("rotary_embedding_test")
diff --git a/backends/vulkan/test/op_tests/utils/gen_correctness_base.py b/backends/vulkan/test/op_tests/utils/gen_correctness_base.py
index 3d9aa6aa80b..d7e38969452 100644
--- a/backends/vulkan/test/op_tests/utils/gen_correctness_base.py
+++ b/backends/vulkan/test/op_tests/utils/gen_correctness_base.py
@@ -45,8 +45,13 @@ class GeneratedOpsTest_{op_name} : public ::testing::Test {{
 test_suite_template = """
 TEST_P(GeneratedOpsTest_{op_name}, {case_name}) {{
 {create_ref_data}
+try {{
 {create_and_check_out}
 }}
+catch (const vkcompute::vkapi::ShaderNotSupportedError& e) {{
+  GTEST_SKIP() << e.what();
+}}
+}}
 """
 
 