diff --git a/backends/vulkan/runtime/VulkanBackend.cpp b/backends/vulkan/runtime/VulkanBackend.cpp
index 7b138072d50..67b646ae1a8 100644
--- a/backends/vulkan/runtime/VulkanBackend.cpp
+++ b/backends/vulkan/runtime/VulkanBackend.cpp
@@ -86,6 +86,32 @@ vkapi::ScalarType get_scalar_type(const vkgraph::VkDataType& vk_datatype) {
       return vkapi::kFloat;
     case vkgraph::VkDataType::FLOAT64:
       return vkapi::kDouble;
+    default:
+      VK_THROW("Invalid VkDataType type encountered!");
+  }
+}
+
+vkapi::ScalarType equivalent_scalar_type(
+    const executorch::runtime::etensor::ScalarType& et_datatype) {
+  switch (et_datatype) {
+    case executorch::runtime::etensor::ScalarType::Byte:
+      return vkapi::kByte;
+    case executorch::runtime::etensor::ScalarType::Char:
+      return vkapi::kChar;
+    case executorch::runtime::etensor::ScalarType::Int:
+      return vkapi::kInt;
+    case executorch::runtime::etensor::ScalarType::Long:
+      return vkapi::kLong;
+    case executorch::runtime::etensor::ScalarType::Half:
+      return vkapi::kHalf;
+    case executorch::runtime::etensor::ScalarType::Float:
+      return vkapi::kFloat;
+    case executorch::runtime::etensor::ScalarType::Double:
+      return vkapi::kDouble;
+    case executorch::runtime::etensor::ScalarType::Bool:
+      return vkapi::kBool;
+    default:
+      VK_THROW("Invalid etensor::ScalarType encountered!");
   }
 }
 
@@ -343,6 +369,15 @@ class GraphBuilder {
     }
   }
 
+  vkapi::ScalarType get_staging_scalar_type_of(const uint32_t fb_id) {
+    VkTensorPtr tensor_fb =
+        flatbuffer_->values()->Get(fb_id)->value_as_VkTensor();
+    if (tensor_fb->staging_datatype() == vkgraph::VkDataType::UNSET) {
+      return get_scalar_type(tensor_fb->datatype());
+    }
+    return get_scalar_type(tensor_fb->staging_datatype());
+  }
+
   void build_graph() {
     // Resize the mapping to the number of values in the flatbuffer
     resize(flatbuffer_->values()->size());
@@ -359,7 +394,8 @@ class GraphBuilder {
     for (const uint32_t fb_id : *flatbuffer_->input_ids()) {
       const ValueRef ref = get_fb_id_valueref(fb_id);
       if (compute_graph_->val_is_tensor(ref)) {
-        compute_graph_->set_input_tensor(ref);
+        compute_graph_->set_input_tensor(
+            ref, get_staging_scalar_type_of(fb_id));
       } else {
         compute_graph_->set_val_as_input(ref);
       }
@@ -384,7 +420,12 @@ class GraphBuilder {
     // values as well if the source graph returns parameter nodes.
     for (const uint32_t fb_id : *flatbuffer_->output_ids()) {
       const ValueRef ref = get_fb_id_valueref(fb_id);
-      compute_graph_->set_output_value(ref);
+      if (compute_graph_->val_is_tensor(ref)) {
+        compute_graph_->set_output_tensor(
+            ref, get_staging_scalar_type_of(fb_id));
+      } else {
+        compute_graph_->set_output_value(ref);
+      }
     }
 
     if (compute_graph_->graphconfig().enable_querypool) {
@@ -582,10 +623,11 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
         bool was_resized =
             maybe_resize_input(compute_graph, i, args[i]->toTensor());
         should_propagate_resize = should_propagate_resize || was_resized;
-        compute_graph->copy_into_staging(
+        compute_graph->maybe_cast_and_copy_into_staging(
             compute_graph->inputs()[i].staging,
             args[i]->toTensor().const_data_ptr(),
-            args[i]->toTensor().numel());
+            args[i]->toTensor().numel(),
+            equivalent_scalar_type(args[i]->toTensor().scalar_type()));
       } else if (compute_graph->val_is_symint(iref)) {
         VK_CHECK_COND(
             args[i]->isTensor(),
@@ -617,10 +659,11 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
       maybe_resize_output(compute_graph, i, args[o]->toTensor());
       // args holds inputs directly followed by outputs, so the i'th output
       // for compute_graph corresponds to the o'th arg
-      compute_graph->copy_from_staging(
+      compute_graph->maybe_cast_and_copy_from_staging(
           compute_graph->outputs()[i].staging,
           args[o]->toTensor().mutable_data_ptr(),
-          args[o]->toTensor().numel());
+          args[o]->toTensor().numel(),
+          equivalent_scalar_type(args[o]->toTensor().scalar_type()));
     }
     // TensorRef values represent constant tensors which will not have been
     // modified by the graph execution. Therefore, if a constant tensor is
diff --git a/backends/vulkan/runtime/api/Context.cpp b/backends/vulkan/runtime/api/Context.cpp
index 8599cbfffb6..adb8409d28c 100644
--- a/backends/vulkan/runtime/api/Context.cpp
+++ b/backends/vulkan/runtime/api/Context.cpp
@@ -117,6 +117,18 @@ void Context::check_device_capabilities(const vkapi::ShaderInfo& shader) {
           shader.kernel_name, vkapi::VulkanExtension::INTEGER_DOT_PRODUCT);
     }
   }
+  if (shader.requires_shader_int64) {
+    if (!adapter_p_->supports_int64_shader_types()) {
+      throw vkapi::ShaderNotSupportedError(
+          shader.kernel_name, vkapi::VulkanExtension::SHADER_INT64);
+    }
+  }
+  if (shader.requires_shader_float64) {
+    if (!adapter_p_->supports_float64_shader_types()) {
+      throw vkapi::ShaderNotSupportedError(
+          shader.kernel_name, vkapi::VulkanExtension::SHADER_FLOAT64);
+    }
+  }
 }
 
 vkapi::DescriptorSet Context::get_descriptor_set(
diff --git a/backends/vulkan/runtime/api/containers/StagingBuffer.h b/backends/vulkan/runtime/api/containers/StagingBuffer.h
index 1e9f569fc4a..6d0e5a4a457 100644
--- a/backends/vulkan/runtime/api/containers/StagingBuffer.h
+++ b/backends/vulkan/runtime/api/containers/StagingBuffer.h
@@ -48,7 +48,7 @@ class StagingBuffer final {
     context_p_->register_buffer_cleanup(vulkan_buffer_);
   }
 
-  inline vkapi::ScalarType dtype() {
+  inline vkapi::ScalarType dtype() const {
     return dtype_;
   }
 
@@ -81,6 +81,15 @@ class StagingBuffer final {
         VK_WHOLE_SIZE);
   }
 
+  template <typename SRC_T, typename DST_T>
+  void cast_and_copy_from(const SRC_T* src, const size_t numel) {
+    VK_CHECK_COND(numel <= this->numel());
+    DST_T* dst = reinterpret_cast<DST_T*>(data());
+    for (size_t i = 0; i < numel; ++i) {
+      dst[i] = static_cast<DST_T>(src[i]);
+    }
+  }
+
   inline void copy_to(void* dst, const size_t nbytes) {
     VK_CHECK_COND(nbytes <= this->nbytes());
     vmaInvalidateAllocation(
@@ -91,6 +100,15 @@ class StagingBuffer final {
     memcpy(dst, data(), nbytes);
   }
 
+  template <typename SRC_T, typename DST_T>
+  void cast_and_copy_to(DST_T* dst, const size_t numel) {
+    VK_CHECK_COND(numel <= this->numel());
+    const SRC_T* src = reinterpret_cast<const SRC_T*>(data());
+    for (size_t i = 0; i < numel; ++i) {
+      dst[i] = static_cast<DST_T>(src[i]);
+    }
+  }
+
   inline void set_staging_zeros() {
     memset(data(), 0, nbytes());
   }
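
Note on the two `cast_and_copy_*` templates above: because staging memory is only exposed through a `void*`, the buffer-side element type cannot be deduced and must be supplied as an explicit template argument at the call site. Below is a minimal standalone sketch of the same loop (illustrative names, not the actual vkcompute API), which also shows the silent truncation that narrowing implies:

```cpp
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

// Standalone analogue of StagingBuffer::cast_and_copy_from. The destination
// type DST_T cannot be deduced from the void* staging pointer, so it must be
// given explicitly.
template <typename SRC_T, typename DST_T>
void cast_and_copy_from(void* staging, const SRC_T* src, size_t numel) {
  DST_T* dst = reinterpret_cast<DST_T*>(staging);
  for (size_t i = 0; i < numel; ++i) {
    // Narrowing (e.g. int64 -> int32) truncates values that do not fit;
    // that is the accepted tradeoff when the GPU lacks 64-bit shader types.
    dst[i] = static_cast<DST_T>(src[i]);
  }
}

int main() {
  std::vector<int64_t> host = {1, -2, 1LL << 40};
  std::vector<int32_t> staging(host.size());
  cast_and_copy_from<int64_t, int32_t>(staging.data(), host.data(), host.size());
  // The third element overflows int32 and is truncated.
  std::printf("%d %d %d\n", staging[0], staging[1], staging[2]);
  return 0;
}
```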
diff --git a/backends/vulkan/runtime/gen_vulkan_spv.py b/backends/vulkan/runtime/gen_vulkan_spv.py
index 3f2d616b428..6db2e01d7e2 100644
--- a/backends/vulkan/runtime/gen_vulkan_spv.py
+++ b/backends/vulkan/runtime/gen_vulkan_spv.py
@@ -670,7 +670,7 @@ def addSrcAndYamlFiles(self, src_dir_paths: List[str]) -> None:
                 if len(file) > 1:
                     self.template_yaml_files.append(file)
 
-    def generateVariantCombinations(
+    def generateVariantCombinations(  # noqa: C901
         self,
         iterated_params: Dict[str, Any],
         exclude_params: Optional[Set[str]] = None,
@@ -679,7 +679,25 @@ def generateVariantCombinations(
             exclude_params = set()
         all_iterated_params = []
         for param_name, value_list in iterated_params.items():
-            if param_name not in exclude_params:
+            if re.match(r"^combination\d*$", param_name):
+                param_values = []
+                param_names = value_list["parameter_names"]
+                combos = value_list["combos"]
+                for combo in combos:
+                    parameter_values = combo["parameter_values"]
+                    if "suffix" in combo:
+                        suffix = combo["suffix"]
+                    else:
+                        suffix = ""
+                        for param_value in parameter_values:
+                            if len(str(param_value)) > 0:
+                                suffix += "_" + str(param_value)
+                        suffix = suffix[1:]
+                    param_values.append((param_names, suffix, parameter_values))
+
+                all_iterated_params.append(param_values)
+
+            elif param_name not in exclude_params:
                 param_values = []
                 for value in value_list:
                     if "RANGE" in value:
@@ -713,7 +731,7 @@ def generateVariantCombinations(
 
         return list(product(*all_iterated_params))
 
-    def parseTemplateYaml(self, yaml_file: str) -> None:
+    def parseTemplateYaml(self, yaml_file: str) -> None:  # noqa: C901
         with open(yaml_file) as f:
             contents = yaml.load(f, Loader=UniqueKeyLoader)
             for template_name, params_dict in contents.items():
@@ -762,10 +780,21 @@ def parseTemplateYaml(self, yaml_file: str) -> None:
                             default_params_copy[key] = variant[key]
 
                     variant_name = variant["NAME"]
-                    for param_value in combination:
-                        default_params_copy[param_value[0]] = param_value[2]
-                        if len(str(param_value[1])) > 0:
-                            variant_name = f"{variant_name}_{param_value[1]}"
+
+                    for setting in combination:
+                        param_names = setting[0]
+                        suffix = setting[1]
+                        param_values = setting[2]
+                        if isinstance(param_names, list):
+                            for param_name, param_value in zip(
+                                param_names, param_values
+                            ):
+                                default_params_copy[param_name] = param_value
+                        else:
+                            default_params_copy[param_names] = param_values
+
+                        if len(str(suffix)) > 0:
+                            variant_name = f"{variant_name}_{suffix}"
 
                     default_params_copy["NAME"] = variant_name
                     default_params_copy["VARIANT_NAME"] = variant["NAME"]
@@ -1104,6 +1133,8 @@ class ShaderInfo:
    requires_16bit_storage_ext: bool = False
    requires_8bit_storage_ext: bool = False
    requires_integer_dot_product_ext: bool = False
+   requires_shader_int64_ext: bool = False
+   requires_shader_float64_ext: bool = False
 
 
 def getName(filePath: str) -> str:
@@ -1193,7 +1224,7 @@ def determineDescriptorType(lineStr: str) -> str:
     )
 
 
-def getShaderInfo(srcFilePath: str) -> ShaderInfo:
+def getShaderInfo(srcFilePath: str) -> ShaderInfo:  # noqa: C901
     shader_info = ShaderInfo([], [], "")
     with open(srcFilePath) as srcFile:
         for line in srcFile:
@@ -1216,6 +1247,10 @@ def getShaderInfo(srcFilePath: str) -> ShaderInfo:
                 shader_info.requires_8bit_storage_ext = True
             if "GL_EXT_integer_dot_product" in line:
                 shader_info.requires_integer_dot_product_ext = True
+            if "GL_EXT_shader_explicit_arithmetic_types_int64" in line:
+                shader_info.requires_shader_int64_ext = True
+            if "GL_EXT_shader_explicit_arithmetic_types_float64" in line:
+                shader_info.requires_shader_float64_ext = True
 
     return shader_info
 
@@ -1292,6 +1327,8 @@ def to_cpp_str(val: bool):
         to_cpp_str(shader_info.requires_16bit_storage_ext),
         to_cpp_str(shader_info.requires_8bit_storage_ext),
         to_cpp_str(shader_info.requires_integer_dot_product_ext),
+        to_cpp_str(shader_info.requires_shader_int64_ext),
+        to_cpp_str(shader_info.requires_shader_float64_ext),
     ]
 
     shader_info_str = textwrap.indent(
diff --git a/backends/vulkan/runtime/graph/ComputeGraph.cpp b/backends/vulkan/runtime/graph/ComputeGraph.cpp
index 6609298b0d8..2ec63a89df8 100644
--- a/backends/vulkan/runtime/graph/ComputeGraph.cpp
+++ b/backends/vulkan/runtime/graph/ComputeGraph.cpp
@@ -310,6 +310,8 @@ vkapi::ScalarType ComputeGraph::dtype_of(const ValueRef idx) const {
     return val.toConstTensor().dtype();
   } else if (val.isTensorRef()) {
     return val.toConstTensorRef().dtype;
+  } else if (val.isStaging()) {
+    return val.toConstStaging().dtype();
   } else if (val.isBool()) {
     return vkapi::ScalarType::Bool;
   } else if (val.isDouble()) {
@@ -585,21 +587,45 @@ ValueRef ComputeGraph::get_or_add_value_for_int(const int64_t val) {
   return add_scalar(val);
 }
 
+ValueRef ComputeGraph::set_input_tensor(
+    const ValueRef idx,
+    vkapi::ScalarType staging_dtype) {
+  // For texture storage, the buffer size needs to account for the zero
+  // padding applied by unused texel elements.
+  size_t buf_numel = get_tensor(idx)->staging_buffer_numel();
+  ValueRef staging_idx = add_staging(staging_dtype, buf_numel);
+  add_staging_to_tensor_node(*this, staging_idx, idx);
+  inputs_.push_back({idx, staging_idx});
+  return staging_idx;
+}
+
 ValueRef ComputeGraph::set_input_tensor(
     const ValueRef idx,
     const bool use_staging) {
   if (use_staging) {
     vkapi::ScalarType dtype = get_tensor(idx)->dtype();
-    // For texture storage, the buffer size needs to account for the zero
-    // padding applied by unused texel elements.
-    size_t buf_numel = get_tensor(idx)->staging_buffer_numel();
-    ValueRef staging_idx = add_staging(dtype, buf_numel);
-    add_staging_to_tensor_node(*this, staging_idx, idx);
-    inputs_.push_back({idx, staging_idx});
-    return staging_idx;
-  }
-  inputs_.push_back({idx, kDummyValueRef});
-  return idx;
+    return set_input_tensor(idx, dtype);
+  } else {
+    inputs_.push_back({idx, kDummyValueRef});
+    return idx;
+  }
+}
+
+ValueRef ComputeGraph::set_output_tensor(
+    const ValueRef idx,
+    vkapi::ScalarType staging_dtype) {
+  // For texture storage, the buffer size needs to account for the zero
+  // padding applied by unused texel elements.
+  size_t buf_numel = get_tensor(idx)->staging_buffer_numel();
+  ValueRef staging_idx = add_staging(staging_dtype, buf_numel);
+  // We only run this when the tensor is non-empty. When the underlying
+  // tensor is empty (e.g. padded_numel == 0), no VkImage is allocated to
+  // back the tensor, so the node cannot be bound for execution.
+  if (buf_numel > 0) {
+    add_tensor_to_staging_node(*this, idx, staging_idx);
+  }
+  outputs_.push_back({idx, staging_idx});
+  return staging_idx;
 }
 
 ValueRef ComputeGraph::set_output_tensor(
@@ -607,21 +633,11 @@ ValueRef ComputeGraph::set_output_tensor(
     const bool use_staging) {
   if (use_staging) {
     vkapi::ScalarType dtype = get_tensor(idx)->dtype();
-    // For texture storage, the buffer size needs to account for the zero
-    // padding applied by unused texel elements.
-    size_t buf_numel = get_tensor(idx)->staging_buffer_numel();
-    ValueRef staging_idx = add_staging(dtype, buf_numel);
-    // We only run this when the tensor is non-empty. When the underlying
-    // tensor is empty (e.g. padded_numel == 0), we do not allocate a VkImage to
-    // tensor, we will not be able to bind the node for execution.
-    if (buf_numel > 0) {
-      add_tensor_to_staging_node(*this, idx, staging_idx);
-    }
-    outputs_.push_back({idx, staging_idx});
-    return staging_idx;
+    return set_output_tensor(idx, dtype);
+  } else {
+    outputs_.push_back({idx, kDummyValueRef});
+    return idx;
+  }
-  outputs_.push_back({idx, kDummyValueRef});
-  return idx;
 }
 
 ValueRef ComputeGraph::set_output_value(const ValueRef idx) {
@@ -847,6 +863,36 @@ void ComputeGraph::copy_into_staging(
   staging->copy_from(data, nbytes);
 }
 
+void ComputeGraph::maybe_cast_and_copy_into_staging(
+    const ValueRef idx,
+    const void* data,
+    const size_t numel,
+    const vkapi::ScalarType src_data_dtype) {
+  StagingPtr staging = get_staging(idx);
+  vkapi::ScalarType staging_dtype = staging->dtype();
+  if (src_data_dtype == staging_dtype) {
+    size_t nbytes = numel * vkapi::element_size(staging_dtype);
+    staging->copy_from(data, nbytes);
+    return;
+  } else {
+    // Hard-coded type conversion cases
+    if (src_data_dtype == vkapi::kLong && staging_dtype == vkapi::kInt) {
+      const int64_t* casted_data = reinterpret_cast<const int64_t*>(data);
+      staging->cast_and_copy_from<int64_t, int32_t>(casted_data, numel);
+    } else if (
+        src_data_dtype == vkapi::kDouble && staging_dtype == vkapi::kFloat) {
+      const double* casted_data = reinterpret_cast<const double*>(data);
+      staging->cast_and_copy_from<double, float>(casted_data, numel);
+    } else {
+      VK_THROW(
+          "Unsupported type conversion from ",
+          src_data_dtype,
+          " to staging dtype ",
+          staging_dtype);
+    }
+  }
+}
+
 void ComputeGraph::copy_from_staging(
     const ValueRef idx,
     void* data,
@@ -856,6 +902,36 @@ void ComputeGraph::copy_from_staging(
   staging->copy_to(data, nbytes);
 }
 
+void ComputeGraph::maybe_cast_and_copy_from_staging(
+    const ValueRef idx,
+    void* data,
+    const size_t numel,
+    const vkapi::ScalarType dst_data_dtype) {
+  StagingPtr staging = get_staging(idx);
+  vkapi::ScalarType staging_dtype = staging->dtype();
+  if (dst_data_dtype == staging_dtype) {
+    size_t nbytes = numel * vkapi::element_size(staging_dtype);
+    staging->copy_to(data, nbytes);
+    return;
+  } else {
+    // Hard-coded type conversion cases
+    if (dst_data_dtype == vkapi::kLong && staging_dtype == vkapi::kInt) {
+      int64_t* casted_data = reinterpret_cast<int64_t*>(data);
+      staging->cast_and_copy_to<int32_t, int64_t>(casted_data, numel);
+    } else if (
+        dst_data_dtype == vkapi::kDouble && staging_dtype == vkapi::kFloat) {
+      double* casted_data = reinterpret_cast<double*>(data);
+      staging->cast_and_copy_to<float, double>(casted_data, numel);
+    } else {
+      VK_THROW(
+          "Unsupported type conversion from staging dtype ",
+          staging_dtype,
+          " to ",
+          dst_data_dtype);
+    }
+  }
+}
+
 void ComputeGraph::prepare() {
 #define MERGE_FIELD(field)                    \
   static_cast<uint32_t>(std::ceil(            \
diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h
index 23b5517fd22..baa15233a00 100644
--- a/backends/vulkan/runtime/graph/ComputeGraph.h
+++ b/backends/vulkan/runtime/graph/ComputeGraph.h
@@ -771,7 +771,16 @@ class ComputeGraph final {
    */
   ValueRef get_or_add_value_for_int(const int64_t val);
 
+  ValueRef set_input_tensor(
+      const ValueRef idx,
+      vkapi::ScalarType staging_dtype);
+
   ValueRef set_input_tensor(const ValueRef idx, const bool use_staging = true);
+
+  ValueRef set_output_tensor(
+      const ValueRef idx,
+      vkapi::ScalarType staging_dtype);
+
   ValueRef set_output_tensor(const ValueRef idx, const bool use_staging = true);
 
   ValueRef set_output_value(const ValueRef idx);
@@ -947,8 +956,21 @@ class ComputeGraph final {
   void
   copy_into_staging(const ValueRef idx, const void* data, const size_t numel);
+
+  void maybe_cast_and_copy_into_staging(
+      const ValueRef idx,
+      const void* data,
+      const size_t numel,
+      const vkapi::ScalarType src_data_dtype);
+
   void copy_from_staging(const ValueRef idx, void* data, const size_t numel);
 
+  void maybe_cast_and_copy_from_staging(
+      const ValueRef idx,
+      void* data,
+      const size_t numel,
+      const vkapi::ScalarType dst_data_dtype);
+
 protected:
   // Command Buffer Management
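
The dispatch in `maybe_cast_and_copy_into_staging` above reduces to: a raw byte copy when the dtypes already match, a hard-coded narrowing loop for the two supported 64-to-32-bit pairs, and an error otherwise. A self-contained sketch of that pattern (simplified type tags, not the real vkapi types):

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <stdexcept>

// Illustrative scalar-type tags standing in for vkapi::ScalarType.
enum class ScalarType { Int, Long, Float, Double };

size_t element_size(ScalarType t) {
  return (t == ScalarType::Long || t == ScalarType::Double) ? 8 : 4;
}

// Mirrors the control flow of maybe_cast_and_copy_into_staging.
void copy_into_staging(
    void* staging, ScalarType staging_dtype,
    const void* src, ScalarType src_dtype, size_t numel) {
  if (src_dtype == staging_dtype) {
    // Fast path: dtypes match, plain memcpy.
    std::memcpy(staging, src, numel * element_size(src_dtype));
  } else if (src_dtype == ScalarType::Long && staging_dtype == ScalarType::Int) {
    const int64_t* s = static_cast<const int64_t*>(src);
    int32_t* d = static_cast<int32_t*>(staging);
    for (size_t i = 0; i < numel; ++i) d[i] = static_cast<int32_t>(s[i]);
  } else if (src_dtype == ScalarType::Double && staging_dtype == ScalarType::Float) {
    const double* s = static_cast<const double*>(src);
    float* d = static_cast<float*>(staging);
    for (size_t i = 0; i < numel; ++i) d[i] = static_cast<float>(s[i]);
  } else {
    // Only the two 64 -> 32 bit pairs are supported, matching the VK_THROW path.
    throw std::runtime_error("Unsupported staging dtype conversion");
  }
}
```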
diff --git a/backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.glsl b/backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.glsl
index 6d164ae2645..f61081d33b7 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.glsl
@@ -3,14 +3,16 @@
 #define PRECISION ${PRECISION}
 
 #define T ${buffer_scalar_type(DTYPE)}
+#define DST_T ${buffer_scalar_type(BUF_DTYPE)}
 
 ${define_required_extensions(DTYPE)}
+${define_required_extensions(BUF_DTYPE)}
 
 layout(std430) buffer;
 
 #include "indexing.glslh"
 
-${layout_declare_tensor(B, "w", "nchw_buf", DTYPE, STORAGE)}
+${layout_declare_tensor(B, "w", "nchw_buf", BUF_DTYPE, STORAGE)}
 ${layout_declare_tensor(B, "r", "t_inp", DTYPE, STORAGE)}
 
 ${layout_declare_ubo(B, "BufferMetadata", "inp")}
@@ -32,5 +34,5 @@ void main() {
 
   uint nchwi = tensor_idx_to_contiguous_idx(inp, inp_tidx);
 
-  nchw_buf[nchwi] = t_inp[inp_bufi];
+  nchw_buf[nchwi] = DST_T(t_inp[inp_bufi]);
 }
diff --git a/backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.yaml b/backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.yaml
index 929108cca5e..1ee7d2db8c1 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.yaml
+++ b/backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.yaml
@@ -7,15 +7,19 @@
 buffer_to_nchw:
   parameter_names_with_default_values:
     DTYPE: float
+    BUF_DTYPE: float
     STORAGE: buffer
     USE_PUSH_CONST: True
   generate_variant_forall:
-    DTYPE:
-      - VALUE: half
-      - VALUE: float
-      - VALUE: double
-      - VALUE: int8
-      - VALUE: uint8
-      - VALUE: int32
+    combination:
+      parameter_names: [DTYPE, BUF_DTYPE]
+      combos:
+        - parameter_values: [half, half]
+        - parameter_values: [half, float]
+        - parameter_values: [float, float]
+        - parameter_values: [double, double]
+        - parameter_values: [int8, int8]
+        - parameter_values: [uint8, uint8]
+        - parameter_values: [int32, int32]
   shader_variants:
     - NAME: buffer_to_nchw
diff --git a/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.glsl b/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.glsl
index d7bef9f0163..1498ed01aef 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.glsl
@@ -16,10 +16,11 @@
 ${define_active_storage_type(STORAGE)}
 
 ${define_required_extensions(DTYPE)}
+${define_required_extensions(BUF_DTYPE)}
 
 layout(std430) buffer;
 
-${layout_declare_buffer(B, "w", "buf_out", DTYPE)}
+${layout_declare_buffer(B, "w", "buf_out", BUF_DTYPE)}
 ${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)}
 
 $if USE_PUSH_CONST:
-${layout_declare_buffer(B, "w", "buf_out", DTYPE)} +${layout_declare_buffer(B, "w", "buf_out", BUF_DTYPE)} ${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)} $if USE_PUSH_CONST: diff --git a/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.yaml b/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.yaml index 646d8f1be81..ebbc55dd9dc 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.yaml @@ -7,17 +7,21 @@ image_to_nchw: parameter_names_with_default_values: DTYPE: float + BUF_DTYPE: float STORAGE: texture3d TO_STAGING: True USE_PUSH_CONST: True generate_variant_forall: - DTYPE: - - VALUE: half - - VALUE: float - - VALUE: double - - VALUE: int8 - - VALUE: uint8 - - VALUE: int32 + combination: + parameter_names: [DTYPE, BUF_DTYPE] + combos: + - parameter_values: [half, half] + - parameter_values: [half, float] + - parameter_values: [float, float] + - parameter_values: [double, double] + - parameter_values: [int8, int8] + - parameter_values: [uint8, uint8] + - parameter_values: [int32, int32] shader_variants: - NAME: image_to_nchw_texture3d - NAME: image_to_nchw_texture2d diff --git a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.glsl index 074624dc37e..a16f5405cbb 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.glsl @@ -5,13 +5,14 @@ #define T ${buffer_scalar_type(DTYPE)} ${define_required_extensions(DTYPE)} +${define_required_extensions(BUF_DTYPE)} layout(std430) buffer; #include "indexing.glslh" ${layout_declare_tensor(B, "w", "t_outp", DTYPE, STORAGE)} -${layout_declare_tensor(B, "r", "nchw_in", DTYPE, STORAGE)} +${layout_declare_tensor(B, "r", "nchw_in", BUF_DTYPE, STORAGE)} ${layout_declare_ubo(B, "BufferMetadata", "outp")} @@ -44,5 +45,5 @@ void main() { nchwi = tensor_idx_to_contiguous_idx(outp, outp_tidx); } - t_outp[outp_bufi] = nchw_in[nchwi]; + t_outp[outp_bufi] = T(nchw_in[nchwi]); } diff --git a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.yaml index 9d6c3aa76a9..602fd1bc65a 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.yaml @@ -7,15 +7,19 @@ nchw_to_buffer: parameter_names_with_default_values: DTYPE: float + BUF_DTYPE: float STORAGE: buffer USE_PUSH_CONST: True generate_variant_forall: - DTYPE: - - VALUE: half - - VALUE: float - - VALUE: double - - VALUE: int8 - - VALUE: uint8 - - VALUE: int32 + combination: + parameter_names: [DTYPE, BUF_DTYPE] + combos: + - parameter_values: [half, half] + - parameter_values: [half, float] + - parameter_values: [float, float] + - parameter_values: [double, double] + - parameter_values: [int8, int8] + - parameter_values: [uint8, uint8] + - parameter_values: [int32, int32] shader_variants: - NAME: nchw_to_buffer diff --git a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl index f3f604e10cd..15676fb0500 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl @@ -16,11 +16,12 @@ ${define_active_storage_type(STORAGE)} ${define_required_extensions(DTYPE)} +${define_required_extensions(BUF_DTYPE)} layout(std430) buffer; ${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)} -${layout_declare_buffer(B, "r", 
"buf_in", DTYPE)} +${layout_declare_buffer(B, "r", "buf_in", BUF_DTYPE)} $if USE_PUSH_CONST: layout(push_constant) uniform restrict Block { diff --git a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.yaml b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.yaml index 85119c8d508..f6809e4024a 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.yaml @@ -11,13 +11,16 @@ nchw_to_image: FROM_STAGING: True USE_PUSH_CONST: True generate_variant_forall: - DTYPE: - - VALUE: half - - VALUE: float - - VALUE: double - - VALUE: int8 - - VALUE: uint8 - - VALUE: int32 + combination: + parameter_names: [DTYPE, BUF_DTYPE] + combos: + - parameter_values: [half, half] + - parameter_values: [half, float] + - parameter_values: [float, float] + - parameter_values: [double, double] + - parameter_values: [int8, int8] + - parameter_values: [uint8, uint8] + - parameter_values: [int32, int32] shader_variants: - NAME: nchw_to_image_texture3d - NAME: nchw_to_image_texture2d diff --git a/backends/vulkan/runtime/graph/ops/impl/Clone.cpp b/backends/vulkan/runtime/graph/ops/impl/Clone.cpp index 0ae9d53a481..059aade5b04 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Clone.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Clone.cpp @@ -76,6 +76,7 @@ void add_image_to_buffer_node( const ValueRef buffer) { std::string kernel_name = "clone_image_to_buffer"; add_dtype_suffix(kernel_name, graph.dtype_of(image)); + add_dtype_suffix(kernel_name, graph.dtype_of(buffer)); vkapi::ShaderInfo shader = VK_KERNEL_FROM_STR(kernel_name); graph.execute_nodes().emplace_back(new DynamicDispatchNode( @@ -103,6 +104,7 @@ void add_buffer_to_image_node( const ValueRef image) { std::string kernel_name = "clone_buffer_to_image"; add_dtype_suffix(kernel_name, graph.dtype_of(image)); + add_dtype_suffix(kernel_name, graph.dtype_of(buffer)); vkapi::ShaderInfo shader = VK_KERNEL_FROM_STR(kernel_name); graph.execute_nodes().emplace_back(new DynamicDispatchNode( diff --git a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp index ded1defe973..b83164f27d2 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp @@ -105,7 +105,8 @@ ValueRef prepack_biases( ValueRef v = graph.add_tensor( {out_channels}, graph.dtype_of(weight), storage_type, memory_layout); - vkapi::ShaderInfo shader = get_nchw_to_tensor_shader(graph, v); + vkapi::ShaderInfo shader = + get_nchw_to_tensor_shader(graph, v, graph.dtype_of(weight)); graph.prepack_nodes().emplace_back(new PrepackNode( graph, diff --git a/backends/vulkan/runtime/graph/ops/impl/Staging.cpp b/backends/vulkan/runtime/graph/ops/impl/Staging.cpp index 648d7b8da09..40de9b59e81 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Staging.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Staging.cpp @@ -27,7 +27,10 @@ void add_staging_to_tensor_node( VK_CHECK_COND(graph.val_is_staging(in_staging)); vkapi::ShaderInfo shader = get_nchw_to_tensor_shader( - graph, out_tensor, graph.int8_buffers_enabled()); + graph, + out_tensor, + graph.dtype_of(in_staging), + graph.int8_buffers_enabled()); vkapi::ParamsBindList param_buffers = {}; if (graph.is_buffer_storage(out_tensor)) { @@ -66,16 +69,6 @@ bool is_bitw8_shader(const vkapi::ShaderInfo& shader) { return shader_prefix_str == kBitw8PrefixStr; } -vkapi::ShaderInfo get_tensor_to_staging_shader( - ComputeGraph* graph, - const std::vector& args, - const 
std::vector& resize_args) { - (void)resize_args; - const ValueRef in_tensor = args.at(1).refs.at(0); - return get_tensor_to_nchw_shader( - *graph, in_tensor, graph->int8_buffers_enabled()); -} - utils::uvec3 tensor_to_staging_global_wg_size( ComputeGraph* graph, const vkapi::ShaderInfo& shader, @@ -110,8 +103,11 @@ void add_tensor_to_staging_node( const ValueRef out_staging) { VK_CHECK_COND(graph.val_is_staging(out_staging)); - vkapi::ShaderInfo shader = - get_tensor_to_nchw_shader(graph, in_tensor, graph.int8_buffers_enabled()); + vkapi::ShaderInfo shader = get_tensor_to_nchw_shader( + graph, + in_tensor, + graph.dtype_of(out_staging), + graph.int8_buffers_enabled()); vkapi::ParamsBindList param_buffers = {}; if (graph.is_buffer_storage(in_tensor)) { @@ -151,8 +147,8 @@ void add_prepack_standard_node( const ValueRef tensor_data, const ValueRef tensor, const bool transpose_hw = false) { - vkapi::ShaderInfo shader = - get_nchw_to_tensor_shader(graph, tensor, graph.int8_buffers_enabled()); + vkapi::ShaderInfo shader = get_nchw_to_tensor_shader( + graph, tensor, graph.dtype_of(tensor_data), graph.int8_buffers_enabled()); vkapi::ParamsBindList param_buffers = {}; if (graph.is_buffer_storage(tensor)) { diff --git a/backends/vulkan/runtime/graph/ops/utils/StagingUtils.cpp b/backends/vulkan/runtime/graph/ops/utils/StagingUtils.cpp index c90bfa402bb..c2adca526fb 100644 --- a/backends/vulkan/runtime/graph/ops/utils/StagingUtils.cpp +++ b/backends/vulkan/runtime/graph/ops/utils/StagingUtils.cpp @@ -23,6 +23,7 @@ bool is_bitw8(vkapi::ScalarType dtype) { vkapi::ShaderInfo get_nchw_to_tensor_shader( ComputeGraph& graph, const ValueRef dst, + const vkapi::ScalarType staging_dtype, bool int8_buffer_enabled, bool push_constant_variant) { std::string kernel_name; @@ -45,6 +46,7 @@ vkapi::ShaderInfo get_nchw_to_tensor_shader( if (dst_storage_type == utils::kBuffer) { kernel_name = "nchw_to_buffer"; add_dtype_suffix(kernel_name, dst_dtype); + add_dtype_suffix(kernel_name, staging_dtype); return VK_KERNEL_FROM_STR(kernel_name); } @@ -54,6 +56,7 @@ vkapi::ShaderInfo get_nchw_to_tensor_shader( } add_storage_type_suffix(kernel_name, dst_storage_type); add_dtype_suffix(kernel_name, dst_dtype); + add_dtype_suffix(kernel_name, staging_dtype); return VK_KERNEL_FROM_STR(kernel_name); } @@ -61,6 +64,7 @@ vkapi::ShaderInfo get_nchw_to_tensor_shader( vkapi::ShaderInfo get_tensor_to_nchw_shader( ComputeGraph& graph, const ValueRef src, + const vkapi::ScalarType staging_dtype, bool int8_buffer_enabled, bool push_constant_variant) { std::string kernel_name; @@ -83,6 +87,7 @@ vkapi::ShaderInfo get_tensor_to_nchw_shader( if (src_storage_type == utils::kBuffer) { kernel_name = "buffer_to_nchw"; add_dtype_suffix(kernel_name, src_dtype); + add_dtype_suffix(kernel_name, staging_dtype); return VK_KERNEL_FROM_STR(kernel_name); } @@ -92,6 +97,7 @@ vkapi::ShaderInfo get_tensor_to_nchw_shader( } add_storage_type_suffix(kernel_name, src_storage_type); add_dtype_suffix(kernel_name, src_dtype); + add_dtype_suffix(kernel_name, staging_dtype); return VK_KERNEL_FROM_STR(kernel_name); } diff --git a/backends/vulkan/runtime/graph/ops/utils/StagingUtils.h b/backends/vulkan/runtime/graph/ops/utils/StagingUtils.h index 71c92b833b7..a4419de3932 100644 --- a/backends/vulkan/runtime/graph/ops/utils/StagingUtils.h +++ b/backends/vulkan/runtime/graph/ops/utils/StagingUtils.h @@ -15,11 +15,14 @@ namespace vkcompute { vkapi::ShaderInfo get_nchw_to_tensor_shader( ComputeGraph& graph, const ValueRef dst, + const vkapi::ScalarType staging_dtype, bool 
diff --git a/backends/vulkan/runtime/vk_api/Adapter.cpp b/backends/vulkan/runtime/vk_api/Adapter.cpp
index 0e87dde1922..aa76b202882 100644
--- a/backends/vulkan/runtime/vk_api/Adapter.cpp
+++ b/backends/vulkan/runtime/vk_api/Adapter.cpp
@@ -11,6 +11,7 @@
 #include
 #include
+#include
 
 namespace vkcompute {
 namespace vkapi {
@@ -412,6 +413,11 @@ std::string Adapter::stringize() const {
 #endif /* VK_KHR_shader_float16_int8 */
   ss << "  }" << std::endl;
 
+  ss << "  Shader 64bit Features {" << std::endl;
+  PRINT_BOOL(physical_device_.supports_int64_shader_types, shaderInt64)
+  PRINT_BOOL(physical_device_.supports_float64_shader_types, shaderFloat64)
+  ss << "  }" << std::endl;
+
 #ifdef VK_KHR_shader_integer_dot_product
   ss << "  Shader Integer Dot Product Features {" << std::endl;
   PRINT_PROP(
diff --git a/backends/vulkan/runtime/vk_api/Adapter.h b/backends/vulkan/runtime/vk_api/Adapter.h
index 6a68b487348..65d0977b533 100644
--- a/backends/vulkan/runtime/vk_api/Adapter.h
+++ b/backends/vulkan/runtime/vk_api/Adapter.h
@@ -225,6 +225,14 @@ class Adapter final {
     return physical_device_.supports_int16_shader_types;
   }
 
+  inline bool supports_int64_shader_types() {
+    return physical_device_.supports_int64_shader_types;
+  }
+
+  inline bool supports_float64_shader_types() {
+    return physical_device_.supports_float64_shader_types;
+  }
+
   inline bool has_full_float16_buffers_support() {
     return supports_16bit_storage_buffers() && supports_float16_shader_types();
   }
diff --git a/backends/vulkan/runtime/vk_api/Device.cpp b/backends/vulkan/runtime/vk_api/Device.cpp
index a21130f1231..7a3a825f5ec 100644
--- a/backends/vulkan/runtime/vk_api/Device.cpp
+++ b/backends/vulkan/runtime/vk_api/Device.cpp
@@ -16,6 +16,7 @@
 #include
 #include
 #include
+#include
 
 namespace vkcompute {
 namespace vkapi {
@@ -45,6 +46,8 @@ PhysicalDevice::PhysicalDevice(VkPhysicalDevice physical_device_handle)
       queue_families{},
       num_compute_queues(0),
       supports_int16_shader_types(false),
+      supports_int64_shader_types(false),
+      supports_float64_shader_types(false),
       has_unified_memory(false),
       has_timestamps(false),
       timestamp_period(0),
@@ -97,6 +100,12 @@ PhysicalDevice::PhysicalDevice(VkPhysicalDevice physical_device_handle)
   if (features2.features.shaderInt16 == VK_TRUE) {
     supports_int16_shader_types = true;
   }
+  if (features2.features.shaderInt64 == VK_TRUE) {
+    supports_int64_shader_types = true;
+  }
+  if (features2.features.shaderFloat64 == VK_TRUE) {
+    supports_float64_shader_types = true;
+  }
 
   // Check if there are any memory types have both the HOST_VISIBLE and the
   // DEVICE_LOCAL property flags
diff --git a/backends/vulkan/runtime/vk_api/Device.h b/backends/vulkan/runtime/vk_api/Device.h
index f5b7154d260..917df514c4b 100644
--- a/backends/vulkan/runtime/vk_api/Device.h
+++ b/backends/vulkan/runtime/vk_api/Device.h
@@ -12,7 +12,7 @@
 
 #include
 
-#include
+#include
 #include
 
 namespace vkcompute {
@@ -57,6 +57,8 @@ struct PhysicalDevice final {
   // Metadata
   uint32_t num_compute_queues;
   bool supports_int16_shader_types;
+  bool supports_int64_shader_types;
+  bool supports_float64_shader_types;
   bool has_unified_memory;
   bool has_timestamps;
   float timestamp_period;
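
The `shaderInt64`/`shaderFloat64` bits read in Device.cpp are core `VkPhysicalDeviceFeatures` flags rather than extensions, which is also why Exception.cpp prints the feature names instead of extension strings. A minimal standalone version of the query (assumes a valid `VkPhysicalDevice` obtained elsewhere and Vulkan 1.1+, where `vkGetPhysicalDeviceFeatures2` is core):

```cpp
#include <vulkan/vulkan.h>
#include <cstdio>

// Query the core feature struct and read the 64-bit shader arithmetic bits.
void check_64bit_shader_support(VkPhysicalDevice physical_device) {
  VkPhysicalDeviceFeatures2 features2{};
  features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;
  vkGetPhysicalDeviceFeatures2(physical_device, &features2);

  const bool supports_int64 = features2.features.shaderInt64 == VK_TRUE;
  const bool supports_float64 = features2.features.shaderFloat64 == VK_TRUE;
  std::printf(
      "shaderInt64: %d, shaderFloat64: %d\n", supports_int64, supports_float64);
}
```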
diff --git a/backends/vulkan/runtime/vk_api/Exception.cpp b/backends/vulkan/runtime/vk_api/Exception.cpp
index c07349fa7ca..d3efa81e52a 100644
--- a/backends/vulkan/runtime/vk_api/Exception.cpp
+++ b/backends/vulkan/runtime/vk_api/Exception.cpp
@@ -95,6 +95,12 @@ std::ostream& operator<<(std::ostream& out, const VulkanExtension result) {
     case VulkanExtension::INTEGER_DOT_PRODUCT:
       out << "VK_KHR_shader_integer_dot_product";
       break;
+    case VulkanExtension::SHADER_INT64:
+      out << "shaderInt64";
+      break;
+    case VulkanExtension::SHADER_FLOAT64:
+      out << "shaderFloat64";
+      break;
   }
   return out;
 }
diff --git a/backends/vulkan/runtime/vk_api/Exception.h b/backends/vulkan/runtime/vk_api/Exception.h
index a883a68fefc..aa1ef1f2526 100644
--- a/backends/vulkan/runtime/vk_api/Exception.h
+++ b/backends/vulkan/runtime/vk_api/Exception.h
@@ -83,6 +83,8 @@ enum class VulkanExtension : uint8_t {
   INT16_STORAGE,
   INT8_STORAGE,
   INTEGER_DOT_PRODUCT,
+  SHADER_INT64,
+  SHADER_FLOAT64,
 };
 
 class ShaderNotSupportedError : public std::exception {
diff --git a/backends/vulkan/runtime/vk_api/Shader.cpp b/backends/vulkan/runtime/vk_api/Shader.cpp
index 4356f92efe7..c932d0a264b 100644
--- a/backends/vulkan/runtime/vk_api/Shader.cpp
+++ b/backends/vulkan/runtime/vk_api/Shader.cpp
@@ -32,7 +32,9 @@ ShaderInfo::ShaderInfo(
     const bool requires_shader_int16_ext,
     const bool requires_16bit_storage_ext,
     const bool requires_8bit_storage_ext,
-    const bool requires_integer_dot_product_ext)
+    const bool requires_integer_dot_product_ext,
+    const bool requires_shader_int64_ext,
+    const bool requires_shader_float64_ext)
     : src_code{
           spirv_bin,
           size,
@@ -43,7 +45,9 @@ ShaderInfo::ShaderInfo(
       requires_shader_int16(requires_shader_int16_ext),
       requires_16bit_storage(requires_16bit_storage_ext),
       requires_8bit_storage(requires_8bit_storage_ext),
-      requires_integer_dot_product(requires_integer_dot_product_ext) {
+      requires_integer_dot_product(requires_integer_dot_product_ext),
+      requires_shader_int64(requires_shader_int64_ext),
+      requires_shader_float64(requires_shader_float64_ext) {
 }
 
 bool operator==(const ShaderInfo& _1, const ShaderInfo& _2) {
diff --git a/backends/vulkan/runtime/vk_api/Shader.h b/backends/vulkan/runtime/vk_api/Shader.h
index 21332381406..6311710f02b 100644
--- a/backends/vulkan/runtime/vk_api/Shader.h
+++ b/backends/vulkan/runtime/vk_api/Shader.h
@@ -66,6 +66,8 @@ struct ShaderInfo final {
   bool requires_16bit_storage = false;
   bool requires_8bit_storage = false;
   bool requires_integer_dot_product = false;
+  bool requires_shader_int64 = false;
+  bool requires_shader_float64 = false;
 
   explicit ShaderInfo();
 
@@ -78,7 +80,9 @@ struct ShaderInfo final {
       const bool requires_shader_int16_ext,
       const bool requires_16bit_storage_ext,
       const bool requires_8bit_storage_ext,
-      const bool requires_integer_dot_product_ext);
+      const bool requires_integer_dot_product_ext,
+      const bool requires_shader_int64_ext,
+      const bool requires_shader_float64_ext);
 
   operator bool() const {
     return src_code.bin != nullptr;
diff --git a/backends/vulkan/serialization/schema.fbs b/backends/vulkan/serialization/schema.fbs
index b6670b6f53d..4bc12208ce7 100644
--- a/backends/vulkan/serialization/schema.fbs
+++ b/backends/vulkan/serialization/schema.fbs
@@ -20,6 +20,7 @@ enum VkDataType : byte {
   FLOAT32 = 5,
   FLOAT64 = 6,
   INT64 = 7,
+  UNSET = 127,
 }
 
 // Describes what kind of GPU resource should be used to represent a tensor. The
@@ -55,6 +56,9 @@ table VkTensor {
   storage_type:VkStorageType = DEFAULT_STORAGE;
   // Memory layout that should be used to represent this tensor
   memory_layout:VkMemoryLayout = DEFAULT_LAYOUT;
+  // dtype to use for the staging buffer. This may be different from the tensor's
+  // datatype if force_fp16 is enabled, which forces all float tensors to be
+  // represented as fp16.
+  staging_datatype:VkDataType = UNSET;
 }
 
 table Null {}
diff --git a/backends/vulkan/serialization/vulkan_graph_builder.py b/backends/vulkan/serialization/vulkan_graph_builder.py
index 78ac51c8808..43ea6c7ce30 100644
--- a/backends/vulkan/serialization/vulkan_graph_builder.py
+++ b/backends/vulkan/serialization/vulkan_graph_builder.py
@@ -50,10 +50,12 @@ def __init__(
         program: ExportedProgram,
         delegate_mapping_builder: DelegateMappingBuilder,
         downcast_64_bit: bool = True,
+        force_fp16: bool = False,
     ) -> None:
         self.program = program
         self.delegate_mapping_builder = delegate_mapping_builder
         self.downcast_64_bit = downcast_64_bit
+        self.force_fp16 = force_fp16
         self.chain = []
         self.values = []
         self.input_ids = []
@@ -135,6 +137,12 @@ def maybe_add_constant_tensor(self, node: Node) -> int:
         if is_param_node(self.program, node):
             tensor = self.get_param_tensor(node)
 
+            effective_dtype = self.get_effective_dtype(tensor.dtype)
+
+            # Convert the tensor dtype if needed
+            if tensor.dtype != effective_dtype:
+                tensor = tensor.to(effective_dtype)
+
             # Serialize tensor data to bytes
             tensor = tensor.contiguous()
             size = tensor.untyped_storage().nbytes()
@@ -222,6 +230,29 @@ def create_symint_value(self) -> int:
         self.values.append(vk_graph_schema.VkValue(vk_graph_schema.SymInt(0)))
         return new_id
 
+    def get_effective_dtype(self, dtype: torch.dtype) -> torch.dtype:
+        if self.downcast_64_bit and dtype == torch.float64:
+            return torch.float32
+        elif self.downcast_64_bit and dtype == torch.int64:
+            return torch.int32
+        elif self.force_fp16 and dtype == torch.float32:
+            return torch.float16
+        else:
+            return dtype
+
+    def get_staging_dtype(self, dtype: torch.dtype) -> torch.dtype:
+        # Since 64 bit types are not guaranteed to be supported on all GPUs,
+        # the conversion between 32 bit and 64 bit types is handled on the CPU
+        # side. The conversion occurs when copying the staging buffer contents
+        # to/from ETensor data pointers, rather than in the shader that copies
+        # between the GPU buffer/image and the staging buffer.
+        if self.downcast_64_bit and dtype == torch.float64:
+            return torch.float32
+        elif self.downcast_64_bit and dtype == torch.int64:
+            return torch.int32
+        else:
+            return dtype
+
     def create_tensor_value(self, spec: TensorSpec, constant_id: int = -1) -> int:
         # Negative id indicates that this tensor will have its own dedicated memory.
         mem_obj_id = -1
@@ -236,14 +267,16 @@ def create_tensor_value(self, spec: TensorSpec, constant_id: int = -1) -> int:
             storage_type = spec.etvk_node_repr.storage_type
             memory_layout = spec.etvk_node_repr.memory_layout
 
-        # Apply downcast logic before getting VK datatype
-        effective_dtype = spec.dtype
-        if self.downcast_64_bit and spec.dtype == torch.float64:
-            effective_dtype = torch.float32
-        elif self.downcast_64_bit and spec.dtype == torch.int64:
-            effective_dtype = torch.int32
+        effective_dtype = self.get_effective_dtype(spec.dtype)
 
+        # For constant tensors, the datatype of the original tensor will have been
+        # converted to the effective dtype. Otherwise, the dtype of the staging
+        # buffer for inputs/outputs should match the original tensor dtype.
+        staging_dtype = (
+            effective_dtype if constant_id >= 0 else self.get_staging_dtype(spec.dtype)
+        )
         datatype = self.get_vk_datatype(effective_dtype)
+        staging_datatype = self.get_vk_datatype(staging_dtype)
 
         new_id = len(self.values)
         self.values.append(
@@ -255,6 +288,7 @@ def create_tensor_value(self, spec: TensorSpec, constant_id: int = -1) -> int:
                 mem_obj_id=mem_obj_id,
                 storage_type=storage_type,
                 memory_layout=memory_layout,
+                staging_datatype=staging_datatype,
             )
         )
diff --git a/backends/vulkan/serialization/vulkan_graph_schema.py b/backends/vulkan/serialization/vulkan_graph_schema.py
index aa7641bd927..cf5326f40cf 100644
--- a/backends/vulkan/serialization/vulkan_graph_schema.py
+++ b/backends/vulkan/serialization/vulkan_graph_schema.py
@@ -31,6 +31,7 @@ class VkDataType(IntEnum):
     FLOAT32 = 5
     FLOAT64 = 6
     INT64 = 7
+    UNSET = 127
 
 
 class VkStorageType(IntEnum):
@@ -61,6 +62,7 @@ class VkTensor:
     mem_obj_id: int
     storage_type: VkStorageType = VkStorageType.DEFAULT_STORAGE
     memory_layout: VkMemoryLayout = VkMemoryLayout.DEFAULT_LAYOUT
+    staging_datatype: VkDataType = VkDataType.UNSET
 
 
 @dataclass
diff --git a/backends/vulkan/test/utils/test_utils.cpp b/backends/vulkan/test/utils/test_utils.cpp
index 07d28229221..f00dfa20976 100644
--- a/backends/vulkan/test/utils/test_utils.cpp
+++ b/backends/vulkan/test/utils/test_utils.cpp
@@ -44,6 +44,7 @@ vkapi::ShaderInfo get_nchw_to_tensor_shader(
   if (v_dst.storage_type() == utils::kBuffer) {
     kernel_name = "nchw_to_buffer";
     add_dtype_suffix(kernel_name, v_dst.dtype());
+    add_dtype_suffix(kernel_name, v_dst.dtype());
     return VK_KERNEL_FROM_STR(kernel_name);
   }
 
@@ -53,6 +54,7 @@ vkapi::ShaderInfo get_nchw_to_tensor_shader(
   }
   add_storage_type_suffix(kernel_name, v_dst.storage_type());
   add_dtype_suffix(kernel_name, v_dst.dtype());
+  add_dtype_suffix(kernel_name, v_dst.dtype());
 
   return VK_KERNEL_FROM_STR(kernel_name);
 }
@@ -78,6 +80,7 @@ vkapi::ShaderInfo get_tensor_to_nchw_shader(
   if (v_src.storage_type() == utils::kBuffer) {
     kernel_name = "buffer_to_nchw";
     add_dtype_suffix(kernel_name, v_src.dtype());
+    add_dtype_suffix(kernel_name, v_src.dtype());
     return VK_KERNEL_FROM_STR(kernel_name);
   }
 
@@ -87,6 +90,7 @@ vkapi::ShaderInfo get_tensor_to_nchw_shader(
   }
   add_storage_type_suffix(kernel_name, v_src.storage_type());
   add_dtype_suffix(kernel_name, v_src.dtype());
+  add_dtype_suffix(kernel_name, v_src.dtype());
 
   return VK_KERNEL_FROM_STR(kernel_name);
 }
diff --git a/backends/vulkan/vulkan_preprocess.py b/backends/vulkan/vulkan_preprocess.py
index 69d3cdef75d..95da66494e0 100644
--- a/backends/vulkan/vulkan_preprocess.py
+++ b/backends/vulkan/vulkan_preprocess.py
@@ -112,6 +112,9 @@ def parse_compile_spec(compile_specs: List[CompileSpec]) -> Dict[str, Any]:
         if spec.key == "downcast_64_bit":
             options[spec.key] = bool.from_bytes(spec.value, byteorder="little")
 
+        if spec.key == "force_fp16":
+            options[spec.key] = bool.from_bytes(spec.value, byteorder="little")
+
     # Unhandled options are ignored
     return options
 
@@ -145,6 +148,7 @@ def preprocess(  # noqa: C901
         "memory_layout_override", VkMemoryLayout.TENSOR_WIDTH_PACKED
     )
     downcast_64_bit = compile_options.get("downcast_64_bit", True)
+    force_fp16 = compile_options.get("force_fp16", False)
 
     program = unsafe_remove_auto_functionalized_pass(program)
 
@@ -221,6 +225,7 @@ def preprocess(  # noqa: C901
         program,
         DelegateMappingBuilder(generated_identifiers=True),
         downcast_64_bit=downcast_64_bit,
+        force_fp16=force_fp16,
     )
 
     vk_graph = graph_builder.build_graph()
diff --git a/examples/vulkan/export.py b/examples/vulkan/export.py
index b01bf7d37f3..4d85d83c862 100644
--- a/examples/vulkan/export.py
+++ b/examples/vulkan/export.py
@@ -50,6 +50,16 @@ def main() -> None:
         help=f"provide a model name. Valid ones: {list(MODEL_NAME_TO_MODEL.keys())}",
     )
 
+    parser.add_argument(
+        "-fp16",
+        "--force_fp16",
+        action=argparse.BooleanOptionalAction,
+        default=False,
+        help="Force fp32 tensors to be converted to fp16 internally. Inputs/outputs "
+        "will be converted to/from fp32 when entering/exiting the delegate. Default is "
+        "False",
+    )
+
     parser.add_argument(
         "-s",
         "--strict",
@@ -126,6 +136,8 @@ def main() -> None:
     compile_options = {}
     if args.dynamic or dynamic_shapes is not None:
         compile_options["require_dynamic_shapes"] = True
+    if args.force_fp16:
+        compile_options["force_fp16"] = True
 
     # Configure Edge compilation
     edge_compile_config = EdgeCompileConfig(
@@ -173,12 +185,22 @@ def main() -> None:
     # Save the program
     output_filename = f"{args.model_name}_vulkan"
 
+    atol = 1e-4
+    rtol = 1e-4
+
+    # If forcing fp16, then numerical divergence is expected
+    if args.force_fp16:
+        atol = 2e-2
+        rtol = 1e-1
+
     # Test the model if --test flag is provided
     if args.test:
         test_result = test_utils.run_and_check_output(
             reference_model=model,
            executorch_program=exec_prog,
             sample_inputs=example_inputs,
+            atol=atol,
+            rtol=rtol,
         )
 
         if test_result: