From 343cf5d14aeab6e6c1f6647c91ffcf97514b6e68 Mon Sep 17 00:00:00 2001 From: ssjia Date: Thu, 11 Sep 2025 13:19:38 -0700 Subject: [PATCH] [ET-VK] Enable automatic dtype conversion when copying to/from staging buffer Pull Request resolved: https://github.com/pytorch/executorch/pull/14222 ## Context During export, Vulkan sometimes converts certain tensor dtypes. The most common case of this is that int64 and float64 are internally represented as int32 and float32 tensors. The primary reason for this is to reduce the number of dtype variants that need to be generated for each shader, and also due to the fact that 64-bit types are not guaranteed to be supported. However, this raises an issue if an int64 or float64 tensor is marked as an input/output tensor of the model. The source/destination ETensor will have a different dtype than the internal representation, meaning that the input/output bytes will be interpreted incorrectly. ## Changes This diff fixes this behaviour by introducing the concept of a "staging dtype". This allows the staging buffer of a tensor to have a different dtype than the underlying GPU buffer or texture. When copying to/from the GPU resource, the dtype can then be converted to the correct dtype expected by the client code. As a bonus, also add an optional setting to force fp16 to be used internally for fp32 tensors. This allows models to access half precision inference without needing to incur the cost of dtype conversion ops being inserted into the graph, or needing to manually convert inputs/outputs to half type. ghstack-source-id: 309155136 Differential Revision: [D82234180](https://our.internmc.facebook.com/intern/diff/D82234180/) --- backends/vulkan/runtime/VulkanBackend.cpp | 55 +++++++- backends/vulkan/runtime/api/Context.cpp | 12 ++ .../runtime/api/containers/StagingBuffer.h | 20 ++- backends/vulkan/runtime/gen_vulkan_spv.py | 53 ++++++-- .../vulkan/runtime/graph/ComputeGraph.cpp | 124 ++++++++++++++---- backends/vulkan/runtime/graph/ComputeGraph.h | 22 ++++ .../graph/ops/glsl/buffer_to_nchw.glsl | 6 +- .../graph/ops/glsl/buffer_to_nchw.yaml | 18 ++- .../runtime/graph/ops/glsl/image_to_nchw.glsl | 3 +- .../runtime/graph/ops/glsl/image_to_nchw.yaml | 18 ++- .../graph/ops/glsl/nchw_to_buffer.glsl | 5 +- .../graph/ops/glsl/nchw_to_buffer.yaml | 18 ++- .../runtime/graph/ops/glsl/nchw_to_image.glsl | 3 +- .../runtime/graph/ops/glsl/nchw_to_image.yaml | 17 ++- .../vulkan/runtime/graph/ops/impl/Clone.cpp | 2 + .../runtime/graph/ops/impl/Convolution.cpp | 3 +- .../vulkan/runtime/graph/ops/impl/Staging.cpp | 26 ++-- .../runtime/graph/ops/utils/StagingUtils.cpp | 6 + .../runtime/graph/ops/utils/StagingUtils.h | 3 + backends/vulkan/runtime/vk_api/Adapter.cpp | 6 + backends/vulkan/runtime/vk_api/Adapter.h | 8 ++ backends/vulkan/runtime/vk_api/Device.cpp | 9 ++ backends/vulkan/runtime/vk_api/Device.h | 4 +- backends/vulkan/runtime/vk_api/Exception.cpp | 6 + backends/vulkan/runtime/vk_api/Exception.h | 2 + backends/vulkan/runtime/vk_api/Shader.cpp | 8 +- backends/vulkan/runtime/vk_api/Shader.h | 6 +- backends/vulkan/serialization/schema.fbs | 4 + .../serialization/vulkan_graph_builder.py | 46 ++++++- .../serialization/vulkan_graph_schema.py | 2 + backends/vulkan/test/utils/test_utils.cpp | 4 + backends/vulkan/vulkan_preprocess.py | 5 + examples/vulkan/export.py | 22 ++++ 33 files changed, 447 insertions(+), 99 deletions(-) diff --git a/backends/vulkan/runtime/VulkanBackend.cpp b/backends/vulkan/runtime/VulkanBackend.cpp index 7b138072d50..67b646ae1a8 100644 --- a/backends/vulkan/runtime/VulkanBackend.cpp +++ b/backends/vulkan/runtime/VulkanBackend.cpp @@ -86,6 +86,32 @@ vkapi::ScalarType get_scalar_type(const vkgraph::VkDataType& vk_datatype) { return vkapi::kFloat; case vkgraph::VkDataType::FLOAT64: return vkapi::kDouble; + default: + VK_THROW("Invalid VkDataType type encountered!"); + } +} + +vkapi::ScalarType equivalent_scalar_type( + const executorch::runtime::etensor::ScalarType& et_datatype) { + switch (et_datatype) { + case executorch::runtime::etensor::ScalarType::Byte: + return vkapi::kByte; + case executorch::runtime::etensor::ScalarType::Char: + return vkapi::kChar; + case executorch::runtime::etensor::ScalarType::Int: + return vkapi::kInt; + case executorch::runtime::etensor::ScalarType::Long: + return vkapi::kLong; + case executorch::runtime::etensor::ScalarType::Half: + return vkapi::kHalf; + case executorch::runtime::etensor::ScalarType::Float: + return vkapi::kFloat; + case executorch::runtime::etensor::ScalarType::Double: + return vkapi::kDouble; + case executorch::runtime::etensor::ScalarType::Bool: + return vkapi::kBool; + default: + VK_THROW("Invalid etensor::ScalarType encountered!"); } } @@ -343,6 +369,15 @@ class GraphBuilder { } } + vkapi::ScalarType get_staging_scalar_type_of(const uint32_t fb_id) { + VkTensorPtr tensor_fb = + flatbuffer_->values()->Get(fb_id)->value_as_VkTensor(); + if (tensor_fb->staging_datatype() == vkgraph::VkDataType::UNSET) { + return get_scalar_type(tensor_fb->datatype()); + } + return get_scalar_type(tensor_fb->staging_datatype()); + } + void build_graph() { // Resize the mapping to the number of values in the flatbuffer resize(flatbuffer_->values()->size()); @@ -359,7 +394,8 @@ class GraphBuilder { for (const uint32_t fb_id : *flatbuffer_->input_ids()) { const ValueRef ref = get_fb_id_valueref(fb_id); if (compute_graph_->val_is_tensor(ref)) { - compute_graph_->set_input_tensor(ref); + compute_graph_->set_input_tensor( + ref, get_staging_scalar_type_of(fb_id)); } else { compute_graph_->set_val_as_input(ref); } @@ -384,7 +420,12 @@ class GraphBuilder { // values as well if the source graph returns parameter nodes. for (const uint32_t fb_id : *flatbuffer_->output_ids()) { const ValueRef ref = get_fb_id_valueref(fb_id); - compute_graph_->set_output_value(ref); + if (compute_graph_->val_is_tensor(ref)) { + compute_graph_->set_output_tensor( + ref, get_staging_scalar_type_of(fb_id)); + } else { + compute_graph_->set_output_value(ref); + } } if (compute_graph_->graphconfig().enable_querypool) { @@ -582,10 +623,11 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface { bool was_resized = maybe_resize_input(compute_graph, i, args[i]->toTensor()); should_propagate_resize = should_propagate_resize || was_resized; - compute_graph->copy_into_staging( + compute_graph->maybe_cast_and_copy_into_staging( compute_graph->inputs()[i].staging, args[i]->toTensor().const_data_ptr(), - args[i]->toTensor().numel()); + args[i]->toTensor().numel(), + equivalent_scalar_type(args[i]->toTensor().scalar_type())); } else if (compute_graph->val_is_symint(iref)) { VK_CHECK_COND( args[i]->isTensor(), @@ -617,10 +659,11 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface { maybe_resize_output(compute_graph, i, args[o]->toTensor()); // args holds inputs directly followed by outputs, so the i'th output // for compute_graph corresponds to the o'th arg - compute_graph->copy_from_staging( + compute_graph->maybe_cast_and_copy_from_staging( compute_graph->outputs()[i].staging, args[o]->toTensor().mutable_data_ptr(), - args[o]->toTensor().numel()); + args[o]->toTensor().numel(), + equivalent_scalar_type(args[o]->toTensor().scalar_type())); } // TensorRef values represent constant tensors which will not have been // modified by the graph execution. Therefore, if a constant tensor is diff --git a/backends/vulkan/runtime/api/Context.cpp b/backends/vulkan/runtime/api/Context.cpp index 8599cbfffb6..adb8409d28c 100644 --- a/backends/vulkan/runtime/api/Context.cpp +++ b/backends/vulkan/runtime/api/Context.cpp @@ -117,6 +117,18 @@ void Context::check_device_capabilities(const vkapi::ShaderInfo& shader) { shader.kernel_name, vkapi::VulkanExtension::INTEGER_DOT_PRODUCT); } } + if (shader.requires_shader_int64) { + if (!adapter_p_->supports_int64_shader_types()) { + throw vkapi::ShaderNotSupportedError( + shader.kernel_name, vkapi::VulkanExtension::SHADER_INT64); + } + } + if (shader.requires_shader_float64) { + if (!adapter_p_->supports_float64_shader_types()) { + throw vkapi::ShaderNotSupportedError( + shader.kernel_name, vkapi::VulkanExtension::SHADER_FLOAT64); + } + } } vkapi::DescriptorSet Context::get_descriptor_set( diff --git a/backends/vulkan/runtime/api/containers/StagingBuffer.h b/backends/vulkan/runtime/api/containers/StagingBuffer.h index 1e9f569fc4a..6d0e5a4a457 100644 --- a/backends/vulkan/runtime/api/containers/StagingBuffer.h +++ b/backends/vulkan/runtime/api/containers/StagingBuffer.h @@ -48,7 +48,7 @@ class StagingBuffer final { context_p_->register_buffer_cleanup(vulkan_buffer_); } - inline vkapi::ScalarType dtype() { + inline vkapi::ScalarType dtype() const { return dtype_; } @@ -81,6 +81,15 @@ class StagingBuffer final { VK_WHOLE_SIZE); } + template + void cast_and_copy_from(const SRC_T* src, const size_t numel) { + VK_CHECK_COND(numel <= this->numel()); + DST_T* dst = reinterpret_cast(data()); + for (size_t i = 0; i < numel; ++i) { + dst[i] = static_cast(src[i]); + } + } + inline void copy_to(void* dst, const size_t nbytes) { VK_CHECK_COND(nbytes <= this->nbytes()); vmaInvalidateAllocation( @@ -91,6 +100,15 @@ class StagingBuffer final { memcpy(dst, data(), nbytes); } + template + void cast_and_copy_to(DST_T* dst, const size_t numel) { + VK_CHECK_COND(numel <= this->numel()); + const SRC_T* src = reinterpret_cast(data()); + for (size_t i = 0; i < numel; ++i) { + dst[i] = static_cast(src[i]); + } + } + inline void set_staging_zeros() { memset(data(), 0, nbytes()); } diff --git a/backends/vulkan/runtime/gen_vulkan_spv.py b/backends/vulkan/runtime/gen_vulkan_spv.py index 3f2d616b428..6db2e01d7e2 100644 --- a/backends/vulkan/runtime/gen_vulkan_spv.py +++ b/backends/vulkan/runtime/gen_vulkan_spv.py @@ -670,7 +670,7 @@ def addSrcAndYamlFiles(self, src_dir_paths: List[str]) -> None: if len(file) > 1: self.template_yaml_files.append(file) - def generateVariantCombinations( + def generateVariantCombinations( # noqa: C901 self, iterated_params: Dict[str, Any], exclude_params: Optional[Set[str]] = None, @@ -679,7 +679,25 @@ def generateVariantCombinations( exclude_params = set() all_iterated_params = [] for param_name, value_list in iterated_params.items(): - if param_name not in exclude_params: + if re.match(r"^combination\d*$", param_name): + param_values = [] + param_names = value_list["parameter_names"] + combos = value_list["combos"] + for combo in combos: + parameter_values = combo["parameter_values"] + if "suffix" in combo: + suffix = combo["suffix"] + else: + suffix = "" + for param_value in parameter_values: + if len(str(param_value)) > 0: + suffix += "_" + str(param_value) + suffix = suffix[1:] + param_values.append((param_names, suffix, parameter_values)) + + all_iterated_params.append(param_values) + + elif param_name not in exclude_params: param_values = [] for value in value_list: if "RANGE" in value: @@ -713,7 +731,7 @@ def generateVariantCombinations( return list(product(*all_iterated_params)) - def parseTemplateYaml(self, yaml_file: str) -> None: + def parseTemplateYaml(self, yaml_file: str) -> None: # noqa: C901 with open(yaml_file) as f: contents = yaml.load(f, Loader=UniqueKeyLoader) for template_name, params_dict in contents.items(): @@ -762,10 +780,21 @@ def parseTemplateYaml(self, yaml_file: str) -> None: default_params_copy[key] = variant[key] variant_name = variant["NAME"] - for param_value in combination: - default_params_copy[param_value[0]] = param_value[2] - if len(str(param_value[1])) > 0: - variant_name = f"{variant_name}_{param_value[1]}" + + for setting in combination: + param_names = setting[0] + suffix = setting[1] + param_values = setting[2] + if isinstance(param_names, list): + for param_name, param_value in zip( + param_names, param_values + ): + default_params_copy[param_name] = param_value + else: + default_params_copy[param_names] = param_values + + if len(str(suffix)) > 0: + variant_name = f"{variant_name}_{suffix}" default_params_copy["NAME"] = variant_name default_params_copy["VARIANT_NAME"] = variant["NAME"] @@ -1104,6 +1133,8 @@ class ShaderInfo: requires_16bit_storage_ext: bool = False requires_8bit_storage_ext: bool = False requires_integer_dot_product_ext: bool = False + requires_shader_int64_ext: bool = False + requires_shader_float64_ext: bool = False def getName(filePath: str) -> str: @@ -1193,7 +1224,7 @@ def determineDescriptorType(lineStr: str) -> str: ) -def getShaderInfo(srcFilePath: str) -> ShaderInfo: +def getShaderInfo(srcFilePath: str) -> ShaderInfo: # noqa: C901 shader_info = ShaderInfo([], [], "") with open(srcFilePath) as srcFile: for line in srcFile: @@ -1216,6 +1247,10 @@ def getShaderInfo(srcFilePath: str) -> ShaderInfo: shader_info.requires_8bit_storage_ext = True if "GL_EXT_integer_dot_product" in line: shader_info.requires_integer_dot_product_ext = True + if "GL_EXT_shader_explicit_arithmetic_types_int64" in line: + shader_info.requires_shader_int64_ext = True + if "GL_EXT_shader_explicit_arithmetic_types_float64" in line: + shader_info.requires_shader_float64_ext = True return shader_info @@ -1292,6 +1327,8 @@ def to_cpp_str(val: bool): to_cpp_str(shader_info.requires_16bit_storage_ext), to_cpp_str(shader_info.requires_8bit_storage_ext), to_cpp_str(shader_info.requires_integer_dot_product_ext), + to_cpp_str(shader_info.requires_shader_int64_ext), + to_cpp_str(shader_info.requires_shader_float64_ext), ] shader_info_str = textwrap.indent( diff --git a/backends/vulkan/runtime/graph/ComputeGraph.cpp b/backends/vulkan/runtime/graph/ComputeGraph.cpp index 6609298b0d8..2ec63a89df8 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.cpp +++ b/backends/vulkan/runtime/graph/ComputeGraph.cpp @@ -310,6 +310,8 @@ vkapi::ScalarType ComputeGraph::dtype_of(const ValueRef idx) const { return val.toConstTensor().dtype(); } else if (val.isTensorRef()) { return val.toConstTensorRef().dtype; + } else if (val.isStaging()) { + return val.toConstStaging().dtype(); } else if (val.isBool()) { return vkapi::ScalarType::Bool; } else if (val.isDouble()) { @@ -585,21 +587,45 @@ ValueRef ComputeGraph::get_or_add_value_for_int(const int64_t val) { return add_scalar(val); } +ValueRef ComputeGraph::set_input_tensor( + const ValueRef idx, + vkapi::ScalarType staging_dtype) { + // For texture storage, the buffer size needs to account for the zero + // padding applied by unused texel elements. + size_t buf_numel = get_tensor(idx)->staging_buffer_numel(); + ValueRef staging_idx = add_staging(staging_dtype, buf_numel); + add_staging_to_tensor_node(*this, staging_idx, idx); + inputs_.push_back({idx, staging_idx}); + return staging_idx; +} + ValueRef ComputeGraph::set_input_tensor( const ValueRef idx, const bool use_staging) { if (use_staging) { vkapi::ScalarType dtype = get_tensor(idx)->dtype(); - // For texture storage, the buffer size needs to account for the zero - // padding applied by unused texel elements. - size_t buf_numel = get_tensor(idx)->staging_buffer_numel(); - ValueRef staging_idx = add_staging(dtype, buf_numel); - add_staging_to_tensor_node(*this, staging_idx, idx); - inputs_.push_back({idx, staging_idx}); - return staging_idx; - } - inputs_.push_back({idx, kDummyValueRef}); - return idx; + return set_input_tensor(idx, dtype); + } else { + inputs_.push_back({idx, kDummyValueRef}); + return idx; + } +} + +ValueRef ComputeGraph::set_output_tensor( + const ValueRef idx, + vkapi::ScalarType staging_dtype) { + // For texture storage, the buffer size needs to account for the zero + // padding applied by unused texel elements. + size_t buf_numel = get_tensor(idx)->staging_buffer_numel(); + ValueRef staging_idx = add_staging(staging_dtype, buf_numel); + // We only run this when the tensor is non-empty. When the underlying + // tensor is empty (e.g. padded_numel == 0), we do not allocate a VkImage to + // tensor, we will not be able to bind the node for execution. + if (buf_numel > 0) { + add_tensor_to_staging_node(*this, idx, staging_idx); + } + outputs_.push_back({idx, staging_idx}); + return staging_idx; } ValueRef ComputeGraph::set_output_tensor( @@ -607,21 +633,11 @@ ValueRef ComputeGraph::set_output_tensor( const bool use_staging) { if (use_staging) { vkapi::ScalarType dtype = get_tensor(idx)->dtype(); - // For texture storage, the buffer size needs to account for the zero - // padding applied by unused texel elements. - size_t buf_numel = get_tensor(idx)->staging_buffer_numel(); - ValueRef staging_idx = add_staging(dtype, buf_numel); - // We only run this when the tensor is non-empty. When the underlying - // tensor is empty (e.g. padded_numel == 0), we do not allocate a VkImage to - // tensor, we will not be able to bind the node for execution. - if (buf_numel > 0) { - add_tensor_to_staging_node(*this, idx, staging_idx); - } - outputs_.push_back({idx, staging_idx}); - return staging_idx; + return set_output_tensor(idx, dtype); + } else { + outputs_.push_back({idx, kDummyValueRef}); + return idx; } - outputs_.push_back({idx, kDummyValueRef}); - return idx; } ValueRef ComputeGraph::set_output_value(const ValueRef idx) { @@ -847,6 +863,36 @@ void ComputeGraph::copy_into_staging( staging->copy_from(data, nbytes); } +void ComputeGraph::maybe_cast_and_copy_into_staging( + const ValueRef idx, + const void* data, + const size_t numel, + const vkapi::ScalarType src_data_dtype) { + StagingPtr staging = get_staging(idx); + vkapi::ScalarType staging_dtype = staging->dtype(); + if (src_data_dtype == staging_dtype) { + size_t nbytes = numel * vkapi::element_size(staging_dtype); + staging->copy_from(data, nbytes); + return; + } else { + // Hard-coded type conversion cases + if (src_data_dtype == vkapi::kLong && staging_dtype == vkapi::kInt) { + const int64_t* casted_data = reinterpret_cast(data); + staging->cast_and_copy_from(casted_data, numel); + } else if ( + src_data_dtype == vkapi::kDouble && staging_dtype == vkapi::kFloat) { + const double* casted_data = reinterpret_cast(data); + staging->cast_and_copy_from(casted_data, numel); + } else { + VK_THROW( + "Unsupported type conversion from ", + src_data_dtype, + " to staging dtype ", + staging_dtype); + } + } +} + void ComputeGraph::copy_from_staging( const ValueRef idx, void* data, @@ -856,6 +902,36 @@ void ComputeGraph::copy_from_staging( staging->copy_to(data, nbytes); } +void ComputeGraph::maybe_cast_and_copy_from_staging( + const ValueRef idx, + void* data, + const size_t numel, + const vkapi::ScalarType dst_data_dtype) { + StagingPtr staging = get_staging(idx); + vkapi::ScalarType staging_dtype = staging->dtype(); + if (dst_data_dtype == staging_dtype) { + size_t nbytes = numel * vkapi::element_size(staging_dtype); + staging->copy_to(data, nbytes); + return; + } else { + // Hard-coded type conversion cases + if (dst_data_dtype == vkapi::kLong && staging_dtype == vkapi::kInt) { + int64_t* casted_data = reinterpret_cast(data); + staging->cast_and_copy_to(casted_data, numel); + } else if ( + dst_data_dtype == vkapi::kDouble && staging_dtype == vkapi::kFloat) { + double* casted_data = reinterpret_cast(data); + staging->cast_and_copy_to(casted_data, numel); + } else { + VK_THROW( + "Unsupported type conversion from staging dtype ", + staging_dtype, + " to ", + dst_data_dtype); + } + } +} + void ComputeGraph::prepare() { #define MERGE_FIELD(field) \ static_cast(std::ceil( \ diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h index 23b5517fd22..baa15233a00 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.h +++ b/backends/vulkan/runtime/graph/ComputeGraph.h @@ -771,7 +771,16 @@ class ComputeGraph final { */ ValueRef get_or_add_value_for_int(const int64_t val); + ValueRef set_input_tensor( + const ValueRef idx, + vkapi::ScalarType staging_dtype); + ValueRef set_input_tensor(const ValueRef idx, const bool use_staging = true); + + ValueRef set_output_tensor( + const ValueRef idx, + vkapi::ScalarType staging_dtype); + ValueRef set_output_tensor(const ValueRef idx, const bool use_staging = true); ValueRef set_output_value(const ValueRef idx); @@ -947,8 +956,21 @@ class ComputeGraph final { void copy_into_staging(const ValueRef idx, const void* data, const size_t numel); + + void maybe_cast_and_copy_into_staging( + const ValueRef idx, + const void* data, + const size_t numel, + const vkapi::ScalarType src_data_dtype); + void copy_from_staging(const ValueRef idx, void* data, const size_t numel); + void maybe_cast_and_copy_from_staging( + const ValueRef idx, + void* data, + const size_t numel, + const vkapi::ScalarType dst_data_dtype); + protected: // Command Buffer Management diff --git a/backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.glsl b/backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.glsl index 6d164ae2645..f61081d33b7 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.glsl @@ -3,14 +3,16 @@ #define PRECISION ${PRECISION} #define T ${buffer_scalar_type(DTYPE)} +#define DST_T ${buffer_scalar_type(BUF_DTYPE)} ${define_required_extensions(DTYPE)} +${define_required_extensions(BUF_DTYPE)} layout(std430) buffer; #include "indexing.glslh" -${layout_declare_tensor(B, "w", "nchw_buf", DTYPE, STORAGE)} +${layout_declare_tensor(B, "w", "nchw_buf", BUF_DTYPE, STORAGE)} ${layout_declare_tensor(B, "r", "t_inp", DTYPE, STORAGE)} ${layout_declare_ubo(B, "BufferMetadata", "inp")} @@ -32,5 +34,5 @@ void main() { uint nchwi = tensor_idx_to_contiguous_idx(inp, inp_tidx); - nchw_buf[nchwi] = t_inp[inp_bufi]; + nchw_buf[nchwi] = DST_T(t_inp[inp_bufi]); } diff --git a/backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.yaml b/backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.yaml index 929108cca5e..1ee7d2db8c1 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.yaml @@ -7,15 +7,19 @@ buffer_to_nchw: parameter_names_with_default_values: DTYPE: float + BUF_DTYPE: float STORAGE: buffer USE_PUSH_CONST: True generate_variant_forall: - DTYPE: - - VALUE: half - - VALUE: float - - VALUE: double - - VALUE: int8 - - VALUE: uint8 - - VALUE: int32 + combination: + parameter_names: [DTYPE, BUF_DTYPE] + combos: + - parameter_values: [half, half] + - parameter_values: [half, float] + - parameter_values: [float, float] + - parameter_values: [double, double] + - parameter_values: [int8, int8] + - parameter_values: [uint8, uint8] + - parameter_values: [int32, int32] shader_variants: - NAME: buffer_to_nchw diff --git a/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.glsl b/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.glsl index d7bef9f0163..1498ed01aef 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.glsl @@ -16,10 +16,11 @@ ${define_active_storage_type(STORAGE)} ${define_required_extensions(DTYPE)} +${define_required_extensions(BUF_DTYPE)} layout(std430) buffer; -${layout_declare_buffer(B, "w", "buf_out", DTYPE)} +${layout_declare_buffer(B, "w", "buf_out", BUF_DTYPE)} ${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)} $if USE_PUSH_CONST: diff --git a/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.yaml b/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.yaml index 646d8f1be81..ebbc55dd9dc 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.yaml @@ -7,17 +7,21 @@ image_to_nchw: parameter_names_with_default_values: DTYPE: float + BUF_DTYPE: float STORAGE: texture3d TO_STAGING: True USE_PUSH_CONST: True generate_variant_forall: - DTYPE: - - VALUE: half - - VALUE: float - - VALUE: double - - VALUE: int8 - - VALUE: uint8 - - VALUE: int32 + combination: + parameter_names: [DTYPE, BUF_DTYPE] + combos: + - parameter_values: [half, half] + - parameter_values: [half, float] + - parameter_values: [float, float] + - parameter_values: [double, double] + - parameter_values: [int8, int8] + - parameter_values: [uint8, uint8] + - parameter_values: [int32, int32] shader_variants: - NAME: image_to_nchw_texture3d - NAME: image_to_nchw_texture2d diff --git a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.glsl index 074624dc37e..a16f5405cbb 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.glsl @@ -5,13 +5,14 @@ #define T ${buffer_scalar_type(DTYPE)} ${define_required_extensions(DTYPE)} +${define_required_extensions(BUF_DTYPE)} layout(std430) buffer; #include "indexing.glslh" ${layout_declare_tensor(B, "w", "t_outp", DTYPE, STORAGE)} -${layout_declare_tensor(B, "r", "nchw_in", DTYPE, STORAGE)} +${layout_declare_tensor(B, "r", "nchw_in", BUF_DTYPE, STORAGE)} ${layout_declare_ubo(B, "BufferMetadata", "outp")} @@ -44,5 +45,5 @@ void main() { nchwi = tensor_idx_to_contiguous_idx(outp, outp_tidx); } - t_outp[outp_bufi] = nchw_in[nchwi]; + t_outp[outp_bufi] = T(nchw_in[nchwi]); } diff --git a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.yaml index 9d6c3aa76a9..602fd1bc65a 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.yaml @@ -7,15 +7,19 @@ nchw_to_buffer: parameter_names_with_default_values: DTYPE: float + BUF_DTYPE: float STORAGE: buffer USE_PUSH_CONST: True generate_variant_forall: - DTYPE: - - VALUE: half - - VALUE: float - - VALUE: double - - VALUE: int8 - - VALUE: uint8 - - VALUE: int32 + combination: + parameter_names: [DTYPE, BUF_DTYPE] + combos: + - parameter_values: [half, half] + - parameter_values: [half, float] + - parameter_values: [float, float] + - parameter_values: [double, double] + - parameter_values: [int8, int8] + - parameter_values: [uint8, uint8] + - parameter_values: [int32, int32] shader_variants: - NAME: nchw_to_buffer diff --git a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl index f3f604e10cd..15676fb0500 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl @@ -16,11 +16,12 @@ ${define_active_storage_type(STORAGE)} ${define_required_extensions(DTYPE)} +${define_required_extensions(BUF_DTYPE)} layout(std430) buffer; ${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)} -${layout_declare_buffer(B, "r", "buf_in", DTYPE)} +${layout_declare_buffer(B, "r", "buf_in", BUF_DTYPE)} $if USE_PUSH_CONST: layout(push_constant) uniform restrict Block { diff --git a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.yaml b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.yaml index 85119c8d508..f6809e4024a 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.yaml @@ -11,13 +11,16 @@ nchw_to_image: FROM_STAGING: True USE_PUSH_CONST: True generate_variant_forall: - DTYPE: - - VALUE: half - - VALUE: float - - VALUE: double - - VALUE: int8 - - VALUE: uint8 - - VALUE: int32 + combination: + parameter_names: [DTYPE, BUF_DTYPE] + combos: + - parameter_values: [half, half] + - parameter_values: [half, float] + - parameter_values: [float, float] + - parameter_values: [double, double] + - parameter_values: [int8, int8] + - parameter_values: [uint8, uint8] + - parameter_values: [int32, int32] shader_variants: - NAME: nchw_to_image_texture3d - NAME: nchw_to_image_texture2d diff --git a/backends/vulkan/runtime/graph/ops/impl/Clone.cpp b/backends/vulkan/runtime/graph/ops/impl/Clone.cpp index 0ae9d53a481..059aade5b04 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Clone.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Clone.cpp @@ -76,6 +76,7 @@ void add_image_to_buffer_node( const ValueRef buffer) { std::string kernel_name = "clone_image_to_buffer"; add_dtype_suffix(kernel_name, graph.dtype_of(image)); + add_dtype_suffix(kernel_name, graph.dtype_of(buffer)); vkapi::ShaderInfo shader = VK_KERNEL_FROM_STR(kernel_name); graph.execute_nodes().emplace_back(new DynamicDispatchNode( @@ -103,6 +104,7 @@ void add_buffer_to_image_node( const ValueRef image) { std::string kernel_name = "clone_buffer_to_image"; add_dtype_suffix(kernel_name, graph.dtype_of(image)); + add_dtype_suffix(kernel_name, graph.dtype_of(buffer)); vkapi::ShaderInfo shader = VK_KERNEL_FROM_STR(kernel_name); graph.execute_nodes().emplace_back(new DynamicDispatchNode( diff --git a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp index ded1defe973..b83164f27d2 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp @@ -105,7 +105,8 @@ ValueRef prepack_biases( ValueRef v = graph.add_tensor( {out_channels}, graph.dtype_of(weight), storage_type, memory_layout); - vkapi::ShaderInfo shader = get_nchw_to_tensor_shader(graph, v); + vkapi::ShaderInfo shader = + get_nchw_to_tensor_shader(graph, v, graph.dtype_of(weight)); graph.prepack_nodes().emplace_back(new PrepackNode( graph, diff --git a/backends/vulkan/runtime/graph/ops/impl/Staging.cpp b/backends/vulkan/runtime/graph/ops/impl/Staging.cpp index 648d7b8da09..40de9b59e81 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Staging.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Staging.cpp @@ -27,7 +27,10 @@ void add_staging_to_tensor_node( VK_CHECK_COND(graph.val_is_staging(in_staging)); vkapi::ShaderInfo shader = get_nchw_to_tensor_shader( - graph, out_tensor, graph.int8_buffers_enabled()); + graph, + out_tensor, + graph.dtype_of(in_staging), + graph.int8_buffers_enabled()); vkapi::ParamsBindList param_buffers = {}; if (graph.is_buffer_storage(out_tensor)) { @@ -66,16 +69,6 @@ bool is_bitw8_shader(const vkapi::ShaderInfo& shader) { return shader_prefix_str == kBitw8PrefixStr; } -vkapi::ShaderInfo get_tensor_to_staging_shader( - ComputeGraph* graph, - const std::vector& args, - const std::vector& resize_args) { - (void)resize_args; - const ValueRef in_tensor = args.at(1).refs.at(0); - return get_tensor_to_nchw_shader( - *graph, in_tensor, graph->int8_buffers_enabled()); -} - utils::uvec3 tensor_to_staging_global_wg_size( ComputeGraph* graph, const vkapi::ShaderInfo& shader, @@ -110,8 +103,11 @@ void add_tensor_to_staging_node( const ValueRef out_staging) { VK_CHECK_COND(graph.val_is_staging(out_staging)); - vkapi::ShaderInfo shader = - get_tensor_to_nchw_shader(graph, in_tensor, graph.int8_buffers_enabled()); + vkapi::ShaderInfo shader = get_tensor_to_nchw_shader( + graph, + in_tensor, + graph.dtype_of(out_staging), + graph.int8_buffers_enabled()); vkapi::ParamsBindList param_buffers = {}; if (graph.is_buffer_storage(in_tensor)) { @@ -151,8 +147,8 @@ void add_prepack_standard_node( const ValueRef tensor_data, const ValueRef tensor, const bool transpose_hw = false) { - vkapi::ShaderInfo shader = - get_nchw_to_tensor_shader(graph, tensor, graph.int8_buffers_enabled()); + vkapi::ShaderInfo shader = get_nchw_to_tensor_shader( + graph, tensor, graph.dtype_of(tensor_data), graph.int8_buffers_enabled()); vkapi::ParamsBindList param_buffers = {}; if (graph.is_buffer_storage(tensor)) { diff --git a/backends/vulkan/runtime/graph/ops/utils/StagingUtils.cpp b/backends/vulkan/runtime/graph/ops/utils/StagingUtils.cpp index c90bfa402bb..c2adca526fb 100644 --- a/backends/vulkan/runtime/graph/ops/utils/StagingUtils.cpp +++ b/backends/vulkan/runtime/graph/ops/utils/StagingUtils.cpp @@ -23,6 +23,7 @@ bool is_bitw8(vkapi::ScalarType dtype) { vkapi::ShaderInfo get_nchw_to_tensor_shader( ComputeGraph& graph, const ValueRef dst, + const vkapi::ScalarType staging_dtype, bool int8_buffer_enabled, bool push_constant_variant) { std::string kernel_name; @@ -45,6 +46,7 @@ vkapi::ShaderInfo get_nchw_to_tensor_shader( if (dst_storage_type == utils::kBuffer) { kernel_name = "nchw_to_buffer"; add_dtype_suffix(kernel_name, dst_dtype); + add_dtype_suffix(kernel_name, staging_dtype); return VK_KERNEL_FROM_STR(kernel_name); } @@ -54,6 +56,7 @@ vkapi::ShaderInfo get_nchw_to_tensor_shader( } add_storage_type_suffix(kernel_name, dst_storage_type); add_dtype_suffix(kernel_name, dst_dtype); + add_dtype_suffix(kernel_name, staging_dtype); return VK_KERNEL_FROM_STR(kernel_name); } @@ -61,6 +64,7 @@ vkapi::ShaderInfo get_nchw_to_tensor_shader( vkapi::ShaderInfo get_tensor_to_nchw_shader( ComputeGraph& graph, const ValueRef src, + const vkapi::ScalarType staging_dtype, bool int8_buffer_enabled, bool push_constant_variant) { std::string kernel_name; @@ -83,6 +87,7 @@ vkapi::ShaderInfo get_tensor_to_nchw_shader( if (src_storage_type == utils::kBuffer) { kernel_name = "buffer_to_nchw"; add_dtype_suffix(kernel_name, src_dtype); + add_dtype_suffix(kernel_name, staging_dtype); return VK_KERNEL_FROM_STR(kernel_name); } @@ -92,6 +97,7 @@ vkapi::ShaderInfo get_tensor_to_nchw_shader( } add_storage_type_suffix(kernel_name, src_storage_type); add_dtype_suffix(kernel_name, src_dtype); + add_dtype_suffix(kernel_name, staging_dtype); return VK_KERNEL_FROM_STR(kernel_name); } diff --git a/backends/vulkan/runtime/graph/ops/utils/StagingUtils.h b/backends/vulkan/runtime/graph/ops/utils/StagingUtils.h index 71c92b833b7..a4419de3932 100644 --- a/backends/vulkan/runtime/graph/ops/utils/StagingUtils.h +++ b/backends/vulkan/runtime/graph/ops/utils/StagingUtils.h @@ -15,11 +15,14 @@ namespace vkcompute { vkapi::ShaderInfo get_nchw_to_tensor_shader( ComputeGraph& graph, const ValueRef dst, + const vkapi::ScalarType staging_dtype, bool int8_buffer_enabled = true, bool push_constant_variant = true); + vkapi::ShaderInfo get_tensor_to_nchw_shader( ComputeGraph& graph, const ValueRef src, + const vkapi::ScalarType staging_dtype, bool int8_buffer_enabled = true, bool push_constant_variant = true); diff --git a/backends/vulkan/runtime/vk_api/Adapter.cpp b/backends/vulkan/runtime/vk_api/Adapter.cpp index 0e87dde1922..aa76b202882 100644 --- a/backends/vulkan/runtime/vk_api/Adapter.cpp +++ b/backends/vulkan/runtime/vk_api/Adapter.cpp @@ -11,6 +11,7 @@ #include #include +#include namespace vkcompute { namespace vkapi { @@ -412,6 +413,11 @@ std::string Adapter::stringize() const { #endif /* VK_KHR_shader_float16_int8 */ ss << " }" << std::endl; + ss << " Shader 64bit Features {" << std::endl; + PRINT_BOOL(physical_device_.supports_int64_shader_types, shaderInt64) + PRINT_BOOL(physical_device_.supports_float64_shader_types, shaderFloat64) + ss << " }" << std::endl; + #ifdef VK_KHR_shader_integer_dot_product ss << " Shader Integer Dot Product Features {" << std::endl; PRINT_PROP( diff --git a/backends/vulkan/runtime/vk_api/Adapter.h b/backends/vulkan/runtime/vk_api/Adapter.h index 6a68b487348..65d0977b533 100644 --- a/backends/vulkan/runtime/vk_api/Adapter.h +++ b/backends/vulkan/runtime/vk_api/Adapter.h @@ -225,6 +225,14 @@ class Adapter final { return physical_device_.supports_int16_shader_types; } + inline bool supports_int64_shader_types() { + return physical_device_.supports_int64_shader_types; + } + + inline bool supports_float64_shader_types() { + return physical_device_.supports_float64_shader_types; + } + inline bool has_full_float16_buffers_support() { return supports_16bit_storage_buffers() && supports_float16_shader_types(); } diff --git a/backends/vulkan/runtime/vk_api/Device.cpp b/backends/vulkan/runtime/vk_api/Device.cpp index a21130f1231..7a3a825f5ec 100644 --- a/backends/vulkan/runtime/vk_api/Device.cpp +++ b/backends/vulkan/runtime/vk_api/Device.cpp @@ -16,6 +16,7 @@ #include #include #include +#include namespace vkcompute { namespace vkapi { @@ -45,6 +46,8 @@ PhysicalDevice::PhysicalDevice(VkPhysicalDevice physical_device_handle) queue_families{}, num_compute_queues(0), supports_int16_shader_types(false), + supports_int64_shader_types(false), + supports_float64_shader_types(false), has_unified_memory(false), has_timestamps(false), timestamp_period(0), @@ -97,6 +100,12 @@ PhysicalDevice::PhysicalDevice(VkPhysicalDevice physical_device_handle) if (features2.features.shaderInt16 == VK_TRUE) { supports_int16_shader_types = true; } + if (features2.features.shaderInt64 == VK_TRUE) { + supports_int64_shader_types = true; + } + if (features2.features.shaderFloat64 == VK_TRUE) { + supports_float64_shader_types = true; + } // Check if there are any memory types have both the HOST_VISIBLE and the // DEVICE_LOCAL property flags diff --git a/backends/vulkan/runtime/vk_api/Device.h b/backends/vulkan/runtime/vk_api/Device.h index f5b7154d260..917df514c4b 100644 --- a/backends/vulkan/runtime/vk_api/Device.h +++ b/backends/vulkan/runtime/vk_api/Device.h @@ -12,7 +12,7 @@ #include -#include +#include #include namespace vkcompute { @@ -57,6 +57,8 @@ struct PhysicalDevice final { // Metadata uint32_t num_compute_queues; bool supports_int16_shader_types; + bool supports_int64_shader_types; + bool supports_float64_shader_types; bool has_unified_memory; bool has_timestamps; float timestamp_period; diff --git a/backends/vulkan/runtime/vk_api/Exception.cpp b/backends/vulkan/runtime/vk_api/Exception.cpp index c07349fa7ca..d3efa81e52a 100644 --- a/backends/vulkan/runtime/vk_api/Exception.cpp +++ b/backends/vulkan/runtime/vk_api/Exception.cpp @@ -95,6 +95,12 @@ std::ostream& operator<<(std::ostream& out, const VulkanExtension result) { case VulkanExtension::INTEGER_DOT_PRODUCT: out << "VK_KHR_shader_integer_dot_product"; break; + case VulkanExtension::SHADER_INT64: + out << "shaderInt64"; + break; + case VulkanExtension::SHADER_FLOAT64: + out << "shaderFloat64"; + break; } return out; } diff --git a/backends/vulkan/runtime/vk_api/Exception.h b/backends/vulkan/runtime/vk_api/Exception.h index a883a68fefc..aa1ef1f2526 100644 --- a/backends/vulkan/runtime/vk_api/Exception.h +++ b/backends/vulkan/runtime/vk_api/Exception.h @@ -83,6 +83,8 @@ enum class VulkanExtension : uint8_t { INT16_STORAGE, INT8_STORAGE, INTEGER_DOT_PRODUCT, + SHADER_INT64, + SHADER_FLOAT64, }; class ShaderNotSupportedError : public std::exception { diff --git a/backends/vulkan/runtime/vk_api/Shader.cpp b/backends/vulkan/runtime/vk_api/Shader.cpp index 4356f92efe7..c932d0a264b 100644 --- a/backends/vulkan/runtime/vk_api/Shader.cpp +++ b/backends/vulkan/runtime/vk_api/Shader.cpp @@ -32,7 +32,9 @@ ShaderInfo::ShaderInfo( const bool requires_shader_int16_ext, const bool requires_16bit_storage_ext, const bool requires_8bit_storage_ext, - const bool requires_integer_dot_product_ext) + const bool requires_integer_dot_product_ext, + const bool requires_shader_int64_ext, + const bool requires_shader_float64_ext) : src_code{ spirv_bin, size, @@ -43,7 +45,9 @@ ShaderInfo::ShaderInfo( requires_shader_int16(requires_shader_int16_ext), requires_16bit_storage(requires_16bit_storage_ext), requires_8bit_storage(requires_8bit_storage_ext), - requires_integer_dot_product(requires_integer_dot_product_ext) { + requires_integer_dot_product(requires_integer_dot_product_ext), + requires_shader_int64(requires_shader_int64_ext), + requires_shader_float64(requires_shader_float64_ext) { } bool operator==(const ShaderInfo& _1, const ShaderInfo& _2) { diff --git a/backends/vulkan/runtime/vk_api/Shader.h b/backends/vulkan/runtime/vk_api/Shader.h index 21332381406..6311710f02b 100644 --- a/backends/vulkan/runtime/vk_api/Shader.h +++ b/backends/vulkan/runtime/vk_api/Shader.h @@ -66,6 +66,8 @@ struct ShaderInfo final { bool requires_16bit_storage = false; bool requires_8bit_storage = false; bool requires_integer_dot_product = false; + bool requires_shader_int64 = false; + bool requires_shader_float64 = false; explicit ShaderInfo(); @@ -78,7 +80,9 @@ struct ShaderInfo final { const bool requires_shader_int16_ext, const bool requires_16bit_storage_ext, const bool requires_8bit_storage_ext, - const bool requires_integer_dot_product_ext); + const bool requires_integer_dot_product_ext, + const bool requires_shader_int64_ext, + const bool requires_shader_float64_ext); operator bool() const { return src_code.bin != nullptr; diff --git a/backends/vulkan/serialization/schema.fbs b/backends/vulkan/serialization/schema.fbs index b6670b6f53d..4bc12208ce7 100644 --- a/backends/vulkan/serialization/schema.fbs +++ b/backends/vulkan/serialization/schema.fbs @@ -20,6 +20,7 @@ enum VkDataType : byte { FLOAT32 = 5, FLOAT64 = 6, INT64 = 7, + UNSET = 127, } // Describes what kind of GPU resource should be used to represent a tensor. The @@ -55,6 +56,9 @@ table VkTensor { storage_type:VkStorageType = DEFAULT_STORAGE; // Memory layout that should be used to represent this tensor memory_layout:VkMemoryLayout = DEFAULT_LAYOUT; + // dtype to use for staging buffer. This may be different from the tensor's datatype + // if force_fp16 is enabled to force all float tensors to be represented as fp16. + staging_datatype:VkDataType = UNSET; } table Null {} diff --git a/backends/vulkan/serialization/vulkan_graph_builder.py b/backends/vulkan/serialization/vulkan_graph_builder.py index 78ac51c8808..43ea6c7ce30 100644 --- a/backends/vulkan/serialization/vulkan_graph_builder.py +++ b/backends/vulkan/serialization/vulkan_graph_builder.py @@ -50,10 +50,12 @@ def __init__( program: ExportedProgram, delegate_mapping_builder: DelegateMappingBuilder, downcast_64_bit: bool = True, + force_fp16: bool = False, ) -> None: self.program = program self.delegate_mapping_builder = delegate_mapping_builder self.downcast_64_bit = downcast_64_bit + self.force_fp16 = force_fp16 self.chain = [] self.values = [] self.input_ids = [] @@ -135,6 +137,12 @@ def maybe_add_constant_tensor(self, node: Node) -> int: if is_param_node(self.program, node): tensor = self.get_param_tensor(node) + effective_dtype = self.get_effective_dtype(tensor.dtype) + + # Convert the tensor dtype if needed + if tensor.dtype != effective_dtype: + tensor = tensor.to(effective_dtype) + # Serialize tensor data to bytes tensor = tensor.contiguous() size = tensor.untyped_storage().nbytes() @@ -222,6 +230,29 @@ def create_symint_value(self) -> int: self.values.append(vk_graph_schema.VkValue(vk_graph_schema.SymInt(0))) return new_id + def get_effective_dtype(self, dtype: torch.dtype) -> torch.dtype: + if self.downcast_64_bit and dtype == torch.float64: + return torch.float32 + elif self.downcast_64_bit and dtype == torch.int64: + return torch.int32 + elif self.force_fp16 and dtype == torch.float32: + return torch.float16 + else: + return dtype + + def get_staging_dtype(self, dtype: torch.dtype) -> torch.dtype: + # Since 64 bit types are not guaranteed to be supported on all GPUs, + # the conversion between 32 bit and 64 bit types is handled on the CPU + # side. The conversion will occur when copying the staging buffer + # contents to/from ETensor data pointers, rather than in the shader to + # copy between GPU buffer/image to staging buffer. + if self.downcast_64_bit and dtype == torch.float64: + return torch.float32 + elif self.downcast_64_bit and dtype == torch.int64: + return torch.int32 + else: + return dtype + def create_tensor_value(self, spec: TensorSpec, constant_id: int = -1) -> int: # Negative id indicates that this tensor will have its own dedicated memory. mem_obj_id = -1 @@ -236,14 +267,16 @@ def create_tensor_value(self, spec: TensorSpec, constant_id: int = -1) -> int: storage_type = spec.etvk_node_repr.storage_type memory_layout = spec.etvk_node_repr.memory_layout - # Apply downcast logic before getting VK datatype - effective_dtype = spec.dtype - if self.downcast_64_bit and spec.dtype == torch.float64: - effective_dtype = torch.float32 - elif self.downcast_64_bit and spec.dtype == torch.int64: - effective_dtype = torch.int32 + effective_dtype = self.get_effective_dtype(spec.dtype) + # For constant tensors, the datatype of the original tensor will have been + # converted to the effective dtype. Otherwise, the type of the staging buffer + # for inputs/outputs should match the original tensor dtype. + staging_dtype = ( + effective_dtype if constant_id >= 0 else self.get_staging_dtype(spec.dtype) + ) datatype = self.get_vk_datatype(effective_dtype) + staging_datatype = self.get_vk_datatype(staging_dtype) new_id = len(self.values) self.values.append( @@ -255,6 +288,7 @@ def create_tensor_value(self, spec: TensorSpec, constant_id: int = -1) -> int: mem_obj_id=mem_obj_id, storage_type=storage_type, memory_layout=memory_layout, + staging_datatype=staging_datatype, ) ) ) diff --git a/backends/vulkan/serialization/vulkan_graph_schema.py b/backends/vulkan/serialization/vulkan_graph_schema.py index aa7641bd927..cf5326f40cf 100644 --- a/backends/vulkan/serialization/vulkan_graph_schema.py +++ b/backends/vulkan/serialization/vulkan_graph_schema.py @@ -31,6 +31,7 @@ class VkDataType(IntEnum): FLOAT32 = 5 FLOAT64 = 6 INT64 = 7 + UNSET = 127 class VkStorageType(IntEnum): @@ -61,6 +62,7 @@ class VkTensor: mem_obj_id: int storage_type: VkStorageType = VkStorageType.DEFAULT_STORAGE memory_layout: VkMemoryLayout = VkMemoryLayout.DEFAULT_LAYOUT + staging_datatype: VkDataType = VkDataType.UNSET @dataclass diff --git a/backends/vulkan/test/utils/test_utils.cpp b/backends/vulkan/test/utils/test_utils.cpp index 07d28229221..f00dfa20976 100644 --- a/backends/vulkan/test/utils/test_utils.cpp +++ b/backends/vulkan/test/utils/test_utils.cpp @@ -44,6 +44,7 @@ vkapi::ShaderInfo get_nchw_to_tensor_shader( if (v_dst.storage_type() == utils::kBuffer) { kernel_name = "nchw_to_buffer"; add_dtype_suffix(kernel_name, v_dst.dtype()); + add_dtype_suffix(kernel_name, v_dst.dtype()); return VK_KERNEL_FROM_STR(kernel_name); } @@ -53,6 +54,7 @@ vkapi::ShaderInfo get_nchw_to_tensor_shader( } add_storage_type_suffix(kernel_name, v_dst.storage_type()); add_dtype_suffix(kernel_name, v_dst.dtype()); + add_dtype_suffix(kernel_name, v_dst.dtype()); return VK_KERNEL_FROM_STR(kernel_name); } @@ -78,6 +80,7 @@ vkapi::ShaderInfo get_tensor_to_nchw_shader( if (v_src.storage_type() == utils::kBuffer) { kernel_name = "buffer_to_nchw"; add_dtype_suffix(kernel_name, v_src.dtype()); + add_dtype_suffix(kernel_name, v_src.dtype()); return VK_KERNEL_FROM_STR(kernel_name); } @@ -87,6 +90,7 @@ vkapi::ShaderInfo get_tensor_to_nchw_shader( } add_storage_type_suffix(kernel_name, v_src.storage_type()); add_dtype_suffix(kernel_name, v_src.dtype()); + add_dtype_suffix(kernel_name, v_src.dtype()); return VK_KERNEL_FROM_STR(kernel_name); } diff --git a/backends/vulkan/vulkan_preprocess.py b/backends/vulkan/vulkan_preprocess.py index 69d3cdef75d..95da66494e0 100644 --- a/backends/vulkan/vulkan_preprocess.py +++ b/backends/vulkan/vulkan_preprocess.py @@ -112,6 +112,9 @@ def parse_compile_spec(compile_specs: List[CompileSpec]) -> Dict[str, Any]: if spec.key == "downcast_64_bit": options[spec.key] = bool.from_bytes(spec.value, byteorder="little") + if spec.key == "force_fp16": + options[spec.key] = bool.from_bytes(spec.value, byteorder="little") + # Unhandled options are ignored return options @@ -145,6 +148,7 @@ def preprocess( # noqa: C901 "memory_layout_override", VkMemoryLayout.TENSOR_WIDTH_PACKED ) downcast_64_bit = compile_options.get("downcast_64_bit", True) + force_fp16 = compile_options.get("force_fp16", False) program = unsafe_remove_auto_functionalized_pass(program) @@ -221,6 +225,7 @@ def preprocess( # noqa: C901 program, DelegateMappingBuilder(generated_identifiers=True), downcast_64_bit=downcast_64_bit, + force_fp16=force_fp16, ) vk_graph = graph_builder.build_graph() diff --git a/examples/vulkan/export.py b/examples/vulkan/export.py index b01bf7d37f3..4d85d83c862 100644 --- a/examples/vulkan/export.py +++ b/examples/vulkan/export.py @@ -50,6 +50,16 @@ def main() -> None: help=f"provide a model name. Valid ones: {list(MODEL_NAME_TO_MODEL.keys())}", ) + parser.add_argument( + "-fp16", + "--force_fp16", + action=argparse.BooleanOptionalAction, + default=False, + help="Force fp32 tensors to be converted to fp16 internally. Input/s outputs " + "will be converted to/from fp32 when entering/exiting the delegate. Default is " + "False", + ) + parser.add_argument( "-s", "--strict", @@ -126,6 +136,8 @@ def main() -> None: compile_options = {} if args.dynamic or dynamic_shapes is not None: compile_options["require_dynamic_shapes"] = True + if args.force_fp16: + compile_options["force_fp16"] = True # Configure Edge compilation edge_compile_config = EdgeCompileConfig( @@ -173,12 +185,22 @@ def main() -> None: # Save the program output_filename = f"{args.model_name}_vulkan" + atol = 1e-4 + rtol = 1e-4 + + # If forcing fp16, then numerical divergence is expected + if args.force_fp16: + atol = 2e-2 + rtol = 1e-1 + # Test the model if --test flag is provided if args.test: test_result = test_utils.run_and_check_output( reference_model=model, executorch_program=exec_prog, sample_inputs=example_inputs, + atol=atol, + rtol=rtol, ) if test_result: