From 769014f6b6c5abf959067ef5ca6fdba56e6859a0 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Tue, 13 May 2025 01:01:56 -0700 Subject: [PATCH 01/14] Compile API: add suport for OrtModel input --- .../core/session/onnxruntime_c_api.h | 12 +++ .../core/session/onnxruntime_cxx_api.h | 66 ++++++------- .../core/session/onnxruntime_cxx_inline.h | 6 ++ .../core/graph/model_editor_api_types.h | 1 + onnxruntime/core/session/compile_api.cc | 23 +++++ onnxruntime/core/session/compile_api.h | 2 + .../core/session/model_compilation_options.cc | 95 +++++++++---------- .../core/session/model_compilation_options.h | 51 +++++----- .../core/session/model_editor_c_api.cc | 7 +- onnxruntime/core/session/utils.cc | 36 ++++++- onnxruntime/core/session/utils.h | 16 +++- .../test/providers/qnn/qnn_ep_context_test.cc | 63 ++++++++++++ 12 files changed, 260 insertions(+), 118 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 25b6d72394e0c..c0cc59739e076 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -5964,6 +5964,18 @@ struct OrtCompileApi { * \since Version 1.22. */ ORT_API2_STATUS(CompileModel, _In_ const OrtEnv* env, _In_ const OrtModelCompilationOptions* model_options); + + /** \brief Sets the input OrtModel instance to compile. + * + * \param[in] model_compile_options The OrtModelCompilationOptions instance. + * \param[in] input_model The OrtModel instance of the model to compile. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. 
+ */ + ORT_API2_STATUS(ModelCompilationOptions_SetInputModel, _In_ OrtModelCompilationOptions* model_compile_options, + _In_ const OrtModel* input_model); }; ORT_RUNTIME_CLASS(Ep); diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 8876c40fe9e6c..f21c092efb220 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -1138,38 +1138,6 @@ struct SessionOptions : detail::SessionOptionsImpl { ConstSessionOptions GetConst() const { return ConstSessionOptions{this->p_}; } }; -/** \brief Options object used when compiling a model. - * - * Wraps ::OrtModelCompilationOptions object and methods - */ -struct ModelCompilationOptions : detail::Base { - using Base = detail::Base; - using Base::Base; - - explicit ModelCompilationOptions(std::nullptr_t) {} ///< Create an empty ModelCompilationOptions object, must be assigned a valid one to be used. - - ModelCompilationOptions(const Env& env, const SessionOptions& session_options); ///< Wraps OrtApi::CreateModelCompilationOptionsFromSessionOptions - ModelCompilationOptions(const Env& env, ConstSessionOptions session_options); ///< Wraps OrtApi::CreateModelCompilationOptionsFromSessionOptions - - ModelCompilationOptions& SetInputModelPath(const ORTCHAR_T* input_model_path); ///< Wraps OrtApi::ModelCompilationOptions_SetInputModelPath - ModelCompilationOptions& SetInputModelFromBuffer(const void* input_model_data, - size_t input_model_data_size); ///< Wraps OrtApi::ModelCompilationOptions_SetInputModelFromBuffer - ModelCompilationOptions& SetEpContextEmbedMode(bool embed_ep_context_in_model); ///< Wraps OrtApi::ModelCompilationOptions_SetEpContextEmbedMode - ModelCompilationOptions& SetOutputModelPath(const ORTCHAR_T* output_model_path); ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelPath - ModelCompilationOptions& SetOutputModelExternalInitializersFile(const ORTCHAR_T* 
file_path, - size_t initializer_size_threshold); ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelExternalInitializersFile - ModelCompilationOptions& SetOutputModelBuffer(OrtAllocator* allocator, void** output_model_buffer_ptr, - size_t* output_model_buffer_size_ptr); ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelBuffer -}; - -/** \brief Compiles an input model to generate a model with EPContext nodes that execute EP-specific kernels. Wraps OrtApi::CompileModels. - * - * \param env: ORT environment object. - * \param model_compilation_options: Compilation options for a model. - * \return A Status indicating success or failure. - */ -Status CompileModel(const Env& env, const ModelCompilationOptions& model_compilation_options); - /** \brief Wrapper around ::OrtModelMetadata * */ @@ -2873,5 +2841,39 @@ struct Model : detail::ModelImpl { ConstModel GetConst() const { return ConstModel{this->p_}; } }; + +/** \brief Options object used when compiling a model. + * + * Wraps ::OrtModelCompilationOptions object and methods + */ +struct ModelCompilationOptions : detail::Base { + using Base = detail::Base; + using Base::Base; + + explicit ModelCompilationOptions(std::nullptr_t) {} ///< Create an empty ModelCompilationOptions object, must be assigned a valid one to be used. 
+ + ModelCompilationOptions(const Env& env, const SessionOptions& session_options); ///< Wraps OrtApi::CreateModelCompilationOptionsFromSessionOptions + ModelCompilationOptions(const Env& env, ConstSessionOptions session_options); ///< Wraps OrtApi::CreateModelCompilationOptionsFromSessionOptions + + ModelCompilationOptions& SetInputModelPath(const ORTCHAR_T* input_model_path); ///< Wraps OrtApi::ModelCompilationOptions_SetInputModelPath + ModelCompilationOptions& SetInputModelFromBuffer(const void* input_model_data, + size_t input_model_data_size); ///< Wraps OrtApi::ModelCompilationOptions_SetInputModelFromBuffer + ModelCompilationOptions& SetInputModel(ConstModel input_model); ///< Wraps OrtApi::ModelCompilationOptions_SetInputModel + ModelCompilationOptions& SetEpContextEmbedMode(bool embed_ep_context_in_model); ///< Wraps OrtApi::ModelCompilationOptions_SetEpContextEmbedMode + ModelCompilationOptions& SetOutputModelPath(const ORTCHAR_T* output_model_path); ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelPath + ModelCompilationOptions& SetOutputModelExternalInitializersFile(const ORTCHAR_T* file_path, + size_t initializer_size_threshold); ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelExternalInitializersFile + ModelCompilationOptions& SetOutputModelBuffer(OrtAllocator* allocator, void** output_model_buffer_ptr, + size_t* output_model_buffer_size_ptr); ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelBuffer +}; + +/** \brief Compiles an input model to generate a model with EPContext nodes that execute EP-specific kernels. Wraps OrtApi::CompileModels. + * + * \param env: ORT environment object. + * \param model_compilation_options: Compilation options for a model. + * \return A Status indicating success or failure. 
+ */ +Status CompileModel(const Env& env, const ModelCompilationOptions& model_compilation_options); + } // namespace Ort #include "onnxruntime_cxx_inline.h" diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 0d0b3198a8736..66bfe6188798b 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -801,6 +801,12 @@ inline ModelCompilationOptions& ModelCompilationOptions::SetInputModelFromBuffer return *this; } +inline ModelCompilationOptions& ModelCompilationOptions::SetInputModel( + ConstModel input_model) { + Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetInputModel(this->p_, input_model)); + return *this; +} + inline ModelCompilationOptions& ModelCompilationOptions::SetOutputModelPath( const ORTCHAR_T* output_model_path) { Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetOutputModelPath(this->p_, output_model_path)); diff --git a/onnxruntime/core/graph/model_editor_api_types.h b/onnxruntime/core/graph/model_editor_api_types.h index d72bd13093b61..2ee84958e4d8b 100644 --- a/onnxruntime/core/graph/model_editor_api_types.h +++ b/onnxruntime/core/graph/model_editor_api_types.h @@ -1,5 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
+#pragma once #include "core/common/inlined_containers_fwd.h" #include "core/framework/ort_value.h" diff --git a/onnxruntime/core/session/compile_api.cc b/onnxruntime/core/session/compile_api.cc index ad128fee6cc3d..892db3b38ca5e 100644 --- a/onnxruntime/core/session/compile_api.cc +++ b/onnxruntime/core/session/compile_api.cc @@ -10,6 +10,7 @@ #include "core/common/common.h" #include "core/session/allocator_adapters.h" #include "core/framework/error_code_helper.h" +#include "core/graph/model_editor_api_types.h" #include "core/session/abi_session_options_impl.h" #include "core/session/inference_session.h" #include "core/session/model_compilation_options.h" @@ -106,6 +107,27 @@ ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetInputModelFromBuff API_IMPL_END } +ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetInputModel, + _In_ OrtModelCompilationOptions* ort_model_compile_options, + _In_ const OrtModel* input_model) { + API_IMPL_BEGIN +#if !defined(ORT_MINIMAL_BUILD) + auto model_compile_options = reinterpret_cast(ort_model_compile_options); + + if (input_model == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Invalid input model: OrtModel pointer is null"); + } + + model_compile_options->SetInputOrtModel(*input_model); + return nullptr; +#else + ORT_UNUSED_PARAMETER(ort_model_compile_options); + ORT_UNUSED_PARAMETER(input_model_path); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Compile API is not supported in this build"); +#endif // !defined(ORT_MINIMAL_BUILD) + API_IMPL_END +} + ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetOutputModelPath, _In_ OrtModelCompilationOptions* ort_model_compile_options, const ORTCHAR_T* output_model_path) { @@ -229,6 +251,7 @@ static constexpr OrtCompileApi ort_compile_api = { &OrtCompileAPI::ModelCompilationOptions_SetOutputModelBuffer, &OrtCompileAPI::ModelCompilationOptions_SetEpContextEmbedMode, &OrtCompileAPI::CompileModel, + 
&OrtCompileAPI::ModelCompilationOptions_SetInputModel, }; // checks that we don't violate the rule that the functions must remain in the slots they were originally assigned diff --git a/onnxruntime/core/session/compile_api.h b/onnxruntime/core/session/compile_api.h index b8c5211526b9d..537d8580b8649 100644 --- a/onnxruntime/core/session/compile_api.h +++ b/onnxruntime/core/session/compile_api.h @@ -28,5 +28,7 @@ ORT_API_STATUS_IMPL(ModelCompilationOptions_SetOutputModelBuffer, _In_ OrtModelC ORT_API_STATUS_IMPL(ModelCompilationOptions_SetEpContextEmbedMode, _In_ OrtModelCompilationOptions* model_compile_options, bool embed_ep_context_in_model); ORT_API_STATUS_IMPL(CompileModel, _In_ const OrtEnv* env, _In_ const OrtModelCompilationOptions* model_options); +ORT_API_STATUS_IMPL(ModelCompilationOptions_SetInputModel, _In_ OrtModelCompilationOptions* model_compile_options, + _In_ const OrtModel* input_model); } // namespace OrtCompileAPI diff --git a/onnxruntime/core/session/model_compilation_options.cc b/onnxruntime/core/session/model_compilation_options.cc index d0cb092f78843..480946491e1ae 100644 --- a/onnxruntime/core/session/model_compilation_options.cc +++ b/onnxruntime/core/session/model_compilation_options.cc @@ -29,14 +29,16 @@ ModelCompilationOptions::ModelCompilationOptions(const onnxruntime::Environment& } void ModelCompilationOptions::SetInputModelPath(const std::string& input_model_path) { - ResetInputModelSettings(); - input_model_path_ = input_model_path; + input_model_variant_ = input_model_path; } void ModelCompilationOptions::SetInputModelFromBuffer(const void* input_model_data, size_t input_model_data_size) { - ResetInputModelSettings(); - input_model_data_ = input_model_data; - input_model_data_size_ = input_model_data_size; + input_model_variant_ = gsl::span(reinterpret_cast(input_model_data), + input_model_data_size); +} + +void ModelCompilationOptions::SetInputOrtModel(const OrtModel& ort_model) { + input_model_variant_ = &ort_model; } Status 
ModelCompilationOptions::SetOutputModelPath(const std::string& output_model_path) { @@ -108,26 +110,18 @@ const OrtSessionOptions& ModelCompilationOptions::GetSessionOptions() const { return session_options_; } -bool ModelCompilationOptions::InputModelComesFromFile() const { - return !input_model_path_.empty(); -} - -const std::string& ModelCompilationOptions::GetInputModelPath() const { - return input_model_path_; -} - -const void* ModelCompilationOptions::GetInputModelData() const { - return input_model_data_; +const std::string* ModelCompilationOptions::TryGetInputModelPath() const { + return std::get_if(&input_model_variant_); } -size_t ModelCompilationOptions::GetInputModelDataSize() const { - return input_model_data_size_; +const gsl::span* ModelCompilationOptions::TryGetInputModelBuffer() const { + return std::get_if>(&input_model_variant_); } -void ModelCompilationOptions::ResetInputModelSettings() { - input_model_path_.clear(); - input_model_data_ = nullptr; - input_model_data_size_ = 0; +const OrtModel* ModelCompilationOptions::TryGetInputOrtModel() const { + const gsl::not_null* ort_model_ptr_ptr = std::get_if>( + &input_model_variant_); + return (ort_model_ptr_ptr == nullptr) ? 
nullptr : ort_model_ptr_ptr->get(); } Status ModelCompilationOptions::ResetOutputModelSettings() { @@ -139,68 +133,67 @@ Status ModelCompilationOptions::ResetOutputModelSettings() { return session_options_.value.config_options.AddConfigEntry(kOrtSessionOptionEpContextFilePath, ""); } -Status ModelCompilationOptions::CheckInputModelSettings() const { - const bool comes_from_file = !input_model_path_.empty(); - const bool comes_from_memory = input_model_data_ != nullptr; +Status ModelCompilationOptions::Check() const { + ORT_ENFORCE(session_options_.value.ep_context_gen_options.enable); + ORT_ENFORCE(session_options_.value.config_options.GetConfigOrDefault(kOrtSessionOptionsDisableModelCompile, "0") == "0"); - if (!comes_from_file && !comes_from_memory) { + // Check input model settings + if (std::holds_alternative(input_model_variant_)) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input model to compile must be loaded from either a file or a memory buffer"); + "Input model to compile must be loaded from either a file, a memory buffer, ", + "or an OrtModel instance"); } - if (comes_from_file && comes_from_memory) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input model to compile must be loaded from either a file or a memory buffer, ", - "but not both."); - } + const std::string* input_model_path_ptr = TryGetInputModelPath(); + const gsl::span* input_model_buffer_ptr = TryGetInputModelBuffer(); + const OrtModel* ort_model = TryGetInputOrtModel(); - if (comes_from_file && !std::filesystem::exists(input_model_path_)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input model path does not exist: ", input_model_path_); + if (input_model_path_ptr != nullptr && !std::filesystem::exists(*input_model_path_ptr)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input model path does not exist: ", *input_model_path_ptr); } - if (comes_from_memory && input_model_data_size_ == 0) { + if (input_model_buffer_ptr != nullptr && 
input_model_buffer_ptr->size() == 0) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Buffer for input model data has size 0"); } - return Status::OK(); -} + if (ort_model != nullptr && ort_model->graph->nodes.empty()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input OrtModel instance has no nodes"); + } -Status ModelCompilationOptions::CheckOutputModelSettings() const { + // Check output model settings const EpContextModelGenerationOptions& ep_context_gen_options = session_options_.value.ep_context_gen_options; - const bool explicit_writes_to_file = !ep_context_gen_options.output_model_file_path.empty(); - const bool writes_to_buffer = ep_context_gen_options.output_model_buffer_ptr != nullptr; + const bool explicit_output_to_file = !ep_context_gen_options.output_model_file_path.empty(); + const bool output_to_buffer = ep_context_gen_options.output_model_buffer_ptr != nullptr; - if (!explicit_writes_to_file && !writes_to_buffer) { + if (!explicit_output_to_file && !output_to_buffer && input_model_path_ptr != nullptr) { // User did not specify an output file or an output buffer. We default to generating an output file // with a name based on the input file name, so do not return an error. 
return Status::OK(); } - if (explicit_writes_to_file && writes_to_buffer) { + if (!explicit_output_to_file && !output_to_buffer) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Unable to generate an output model path: require an input model path if the location " + "of the output model (e.g., file or buffer) is not specified."); + } + + if (explicit_output_to_file && output_to_buffer) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Output model to compile must be saved either to a file or to a buffer, but not both."); } - if (writes_to_buffer && ep_context_gen_options.output_model_buffer_size_ptr == nullptr) { + if (output_to_buffer && ep_context_gen_options.output_model_buffer_size_ptr == nullptr) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid buffer configuration for output model: size pointer is null"); } - if (writes_to_buffer && ep_context_gen_options.output_model_buffer_allocator == nullptr) { + if (output_to_buffer && ep_context_gen_options.output_model_buffer_allocator == nullptr) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid buffer configuration for output model: allocator is null"); } return Status::OK(); } - -Status ModelCompilationOptions::Check() const { - ORT_ENFORCE(session_options_.value.ep_context_gen_options.enable); - ORT_ENFORCE(session_options_.value.config_options.GetConfigOrDefault(kOrtSessionOptionsDisableModelCompile, "0") == "0"); - ORT_RETURN_IF_ERROR(CheckInputModelSettings()); - ORT_RETURN_IF_ERROR(CheckOutputModelSettings()); - return Status::OK(); -} } // namespace onnxruntime #endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/core/session/model_compilation_options.h b/onnxruntime/core/session/model_compilation_options.h index 9238264003645..f3f73ef262665 100644 --- a/onnxruntime/core/session/model_compilation_options.h +++ b/onnxruntime/core/session/model_compilation_options.h @@ -4,11 +4,14 @@ #if !defined(ORT_MINIMAL_BUILD) #pragma once +#include #include 
#include +#include #include "core/common/status.h" #include "core/common/path_string.h" #include "core/framework/allocator.h" +#include "core/graph/model_editor_api_types.h" #include "core/session/abi_session_options_impl.h" #include "core/session/onnxruntime_c_api.h" #include "core/session/onnxruntime_session_options_config_keys.h" @@ -44,6 +47,13 @@ class ModelCompilationOptions { /// The size in bytes of the input model's buffer void SetInputModelFromBuffer(const void* input_model_data, size_t input_model_data_size); + /// + /// Sets the input OrtModel instance. + /// Overrides any previous call to SetInput*() + /// + /// OrtModel instance + void SetInputOrtModel(const OrtModel& ort_model); + /// /// Sets the file path to store the output/compiled ONNX model. /// Overrides any previous call to SetOutputModelPath() or SetOutputModelBuffer(). @@ -87,30 +97,25 @@ class ModelCompilationOptions { const OrtSessionOptions& GetSessionOptions() const; /// - /// Returns the file path to the input ONNX model. + /// Returns a pointer to the input model's path or nullptr if the input model + /// is not read from file. /// - /// input model's path - const std::string& GetInputModelPath() const; + /// input model's path or nullptr + const std::string* TryGetInputModelPath() const; /// - /// Returns true if the input model is read from a file. + /// Returns a pointer to the input model's bytes buffer or nullptr if the input model + /// is not read from a buffer. /// - /// true if input model comes from a file - bool InputModelComesFromFile() const; + /// input model's bytes buffer or nullptr + const gsl::span* TryGetInputModelBuffer() const; /// - /// Returns the buffer that contains the bytes for the input ONNX model. - /// Returns nullptr if the input model is not stored in a buffer. + /// Returns a pointer to the OrtModel instance for the input model or nullptr if + /// the input model is not stored as an OrtModel. 
/// - /// pointer to input model's buffer - const void* GetInputModelData() const; - - /// - /// Returns the size in bytes of the buffer that contains the input ONNX model. - /// Returns 0 if the input model is not stored in a buffer. - /// - /// input model buffer's size in bytes - size_t GetInputModelDataSize() const; + /// The OrtModel or nullptr + const OrtModel* TryGetInputOrtModel() const; /// /// Checks if the compilation options described by this object are valid. @@ -119,16 +124,16 @@ class ModelCompilationOptions { Status Check() const; private: - void ResetInputModelSettings(); Status ResetOutputModelSettings(); - Status CheckInputModelSettings() const; - Status CheckOutputModelSettings() const; const onnxruntime::Environment& env_; OrtSessionOptions session_options_; - std::string input_model_path_; - const void* input_model_data_ = nullptr; - size_t input_model_data_size_ = 0; + + std::variant, // input model in buffer + gsl::not_null> // input model created via OrtModelEditor + input_model_variant_{}; }; } // namespace onnxruntime #endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/core/session/model_editor_c_api.cc b/onnxruntime/core/session/model_editor_c_api.cc index 7d5d45b7b531b..04b3a1cf87500 100644 --- a/onnxruntime/core/session/model_editor_c_api.cc +++ b/onnxruntime/core/session/model_editor_c_api.cc @@ -216,12 +216,7 @@ ORT_API_STATUS_IMPL(OrtModelEditorAPI::CreateSessionFromModel, _In_ const OrtEnv *out = nullptr; ORT_TRY { - sess = std::make_unique( - options == nullptr ? 
onnxruntime::SessionOptions() : options->value, - env->GetEnvironment()); - - ORT_API_RETURN_IF_STATUS_NOT_OK(sess->Load(*model)); - + ORT_API_RETURN_IF_STATUS_NOT_OK(onnxruntime::CreateSessionFromModel(options, env->GetEnvironment(), model, sess)); ORT_API_RETURN_IF_ERROR(InitializeSession(options, *sess)); *out = reinterpret_cast(sess.release()); diff --git a/onnxruntime/core/session/utils.cc b/onnxruntime/core/session/utils.cc index 8ca4ef6af1f44..46f5497ca2035 100644 --- a/onnxruntime/core/session/utils.cc +++ b/onnxruntime/core/session/utils.cc @@ -262,22 +262,48 @@ OrtStatus* InitializeSession(_In_ const OrtSessionOptions* options, namespace onnxruntime { #if !defined(ORT_MINIMAL_BUILD) +// Creates a session from an OrtModel (created by model editor API). +Status CreateSessionFromModel(const OrtSessionOptions* options, + const Environment& env, + const OrtModel* ort_model, + std::unique_ptr& sess) { + sess = std::make_unique( + options == nullptr ? onnxruntime::SessionOptions() : options->value, + env); + +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) + if (options && !options->custom_op_domains_.empty()) { + ORT_RETURN_IF_ERROR(sess->AddCustomOpDomains(options->custom_op_domains_)); + } +#endif + + ORT_RETURN_IF_ERROR(sess->Load(*ort_model)); + return Status::OK(); +} + Status CompileModel(const Environment& env, const ModelCompilationOptions& model_compile_options) { ORT_RETURN_IF_ERROR(model_compile_options.Check()); std::unique_ptr session; const OrtSessionOptions* session_options = &model_compile_options.GetSessionOptions(); - if (model_compile_options.InputModelComesFromFile()) { - PathString input_model_path = ToPathString(model_compile_options.GetInputModelPath()); + if (const std::string* model_path_ptr = model_compile_options.TryGetInputModelPath(); + model_path_ptr != nullptr) { + PathString input_model_path = ToPathString(*model_path_ptr); ORT_RETURN_IF_ERROR(ToStatus(CreateSessionAndLoadModelImpl(session_options, env, 
input_model_path.c_str(), nullptr, 0, session))); - } else { + } else if (const gsl::span* model_buffer_ptr = model_compile_options.TryGetInputModelBuffer(); + model_buffer_ptr != nullptr) { ORT_RETURN_IF_ERROR(ToStatus(CreateSessionAndLoadModelImpl(session_options, env, nullptr, - model_compile_options.GetInputModelData(), - model_compile_options.GetInputModelDataSize(), + model_buffer_ptr->data(), + model_buffer_ptr->size(), session))); + } else if (const OrtModel* ort_model = model_compile_options.TryGetInputOrtModel(); ort_model != nullptr) { + ORT_RETURN_IF_ERROR(CreateSessionFromModel(session_options, env, ort_model, session)); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Invalid input model passed to CompileModel(): ", + "expected a file path, a buffer, or an OrtModel, but received an unknown kind of model."); } ORT_RETURN_IF_ERROR(ToStatus(InitializeSession(session_options, *session))); diff --git a/onnxruntime/core/session/utils.h b/onnxruntime/core/session/utils.h index 5a5dcae9165ed..a934f19ab1535 100644 --- a/onnxruntime/core/session/utils.h +++ b/onnxruntime/core/session/utils.h @@ -16,13 +16,14 @@ struct OrtSessionOptions; struct OrtStatus; struct OrtPrepackedWeightsContainer; namespace onnxruntime { +class Environment; class InferenceSession; class ModelCompilationOptions; } // namespace onnxruntime #if !defined(ORT_MINIMAL_BUILD) +struct OrtModel; namespace onnxruntime { -class Environment; class EpLibrary; class EpFactoryInternal; struct IExecutionProviderFactory; @@ -44,6 +45,19 @@ OrtStatus* InitializeSession(_In_ const OrtSessionOptions* options, #if !defined(ORT_MINIMAL_BUILD) namespace onnxruntime { +/// +/// Creates a session from an OrtModel instance. +/// +/// Optional session options +/// Reference to the Environment +/// The OrtModel instance +/// The session to create and initialize +/// A Status indicating an error or success. 
+Status CreateSessionFromModel(const OrtSessionOptions* options, + const onnxruntime::Environment& env, + const OrtModel* ort_model, + std::unique_ptr& sess); + /// /// Compiles an ONNX model into a model with EPContext nodes. Each EPContext node represents a subgraph compiled for /// a specific execution provider. diff --git a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc index 8d840b1a3d45f..67e86cf60a63d 100644 --- a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc @@ -557,6 +557,69 @@ TEST_F(QnnHTPBackendTests, CompileApi_FromSessionOptions_OutputModelBuffer_Outpu allocator.Free(output_model_buffer); } +// Tests compiling an OrtModel created using the OrtModelEditor API. +TEST_F(QnnHTPBackendTests, CompileApi_InputOrtModel_OutputFile) { + std::vector>> weights; // Model weights must remain valid through inference + + // Create OrtModel with a Gemm. X input is 3x4, Y initializer is 4x8, Z output is 3x8. 
+ Ort::Graph graph; + std::vector graph_inputs; + std::vector graph_outputs; + + Ort::TensorTypeAndShapeInfo x_tensor_info(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, + {3, 4}, nullptr); + auto x_type_info = Ort::TypeInfo::CreateTensorInfo(x_tensor_info.GetConst()); + graph_inputs.emplace_back("X", x_type_info.GetConst()); + + Ort::TensorTypeAndShapeInfo z_tensor_info(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, + {3, 8}, nullptr); + auto z_type_info = Ort::TypeInfo::CreateTensorInfo(z_tensor_info.GetConst()); + graph_outputs.emplace_back("Z", z_type_info.GetConst()); + + graph.SetInputs(graph_inputs); + graph.SetOutputs(graph_outputs); + + std::vector attrs; + Ort::Node node("Gemm", onnxruntime::kOnnxDomain, "Gemm1", {"X", "Y"}, {"Z"}, attrs); + graph.AddNode(node); + + std::vector y_dims = {4, 8}; + weights.emplace_back(std::make_unique>(32)); + auto& y_values = *weights.back(); + std::iota(y_values.begin(), y_values.end(), 1.0f); + + auto mem_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); + auto y_tensor = Ort::Value::CreateTensor(mem_info, y_values.data(), y_values.size(), y_dims.data(), y_dims.size()); + graph.AddInitializer("Y", y_tensor, /*data is external*/ false); // TODO: external data does not serialize to proto (error) + + std::vector opsets{{onnxruntime::kOnnxDomain, 18}, {onnxruntime::kMSDomain, 1}}; + Ort::Model model(opsets); + model.AddGraph(graph); + + // Initialize session options with QNN EP + Ort::SessionOptions so; + ProviderOptions provider_options; + provider_options["backend_type"] = "htp"; + provider_options["offload_graph_io_quantization"] = "0"; + so.AppendExecutionProvider("QNN", provider_options); + + const ORTCHAR_T* output_model_file = ORT_TSTR("compileapi_ortmodel_ctx.onnx"); + std::filesystem::remove(output_model_file); + + // Create model compilation options from the session options. 
+ Ort::ModelCompilationOptions compile_options(*ort_env, so); + compile_options.SetInputModel(model.GetConst()); + compile_options.SetOutputModelPath(output_model_file); + compile_options.SetEpContextEmbedMode(true); + + // Compile the model. + Ort::Status status = Ort::CompileModel(*ort_env, compile_options); + ASSERT_TRUE(status.IsOK()) << status.GetErrorMessage(); + + // Make sure the compiled model was generated and has the expected number of EPContext nodes. + ASSERT_TRUE(std::filesystem::exists(output_model_file)); +} + // Test that models with 1 non-quantized FusedMatMul node and 1 quantized Add node can still generate the context binary // The generated Onnx model has 1 FusedMatMul node and 1 EPContext node TEST_F(QnnHTPBackendTests, QnnContextBinaryMultiPartitionSupport1) { From ffcfab99cc10a32c4a573bd19041a3666ddbe5b9 Mon Sep 17 00:00:00 2001 From: Adrian Lizarraga Date: Tue, 13 May 2025 01:05:06 -0700 Subject: [PATCH 02/14] Update onnxruntime/core/session/compile_api.cc --- onnxruntime/core/session/compile_api.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/session/compile_api.cc b/onnxruntime/core/session/compile_api.cc index 892db3b38ca5e..29caf93a6ebe1 100644 --- a/onnxruntime/core/session/compile_api.cc +++ b/onnxruntime/core/session/compile_api.cc @@ -122,7 +122,7 @@ ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetInputModel, return nullptr; #else ORT_UNUSED_PARAMETER(ort_model_compile_options); - ORT_UNUSED_PARAMETER(input_model_path); + ORT_UNUSED_PARAMETER(input_model); return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Compile API is not supported in this build"); #endif // !defined(ORT_MINIMAL_BUILD) API_IMPL_END From 025474a8c0c9fb5a4a7bc51edba0284cdfa57f01 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Tue, 13 May 2025 01:18:56 -0700 Subject: [PATCH 03/14] Fix tensorprotoutils when reading ext data from a memory buffer --- onnxruntime/core/framework/tensorprotoutils.cc | 17 
++++++++++++----- .../test/providers/qnn/qnn_ep_context_test.cc | 3 ++- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/onnxruntime/core/framework/tensorprotoutils.cc b/onnxruntime/core/framework/tensorprotoutils.cc index 94a2a6677358e..83ad58dc610e9 100644 --- a/onnxruntime/core/framework/tensorprotoutils.cc +++ b/onnxruntime/core/framework/tensorprotoutils.cc @@ -179,11 +179,18 @@ Status ReadExternalDataForTensor(const ONNX_NAMESPACE::TensorProto& tensor_proto GetExternalDataInfo(tensor_proto, tensor_proto_dir, external_file_path, file_offset, tensor_byte_size)); unpacked_tensor.resize(tensor_byte_size); - ORT_RETURN_IF_ERROR(onnxruntime::Env::Default().ReadFileIntoBuffer( - external_file_path.c_str(), - file_offset, - tensor_byte_size, - gsl::make_span(reinterpret_cast(unpacked_tensor.data()), tensor_byte_size))); + if (external_file_path == onnxruntime::utils::kTensorProtoMemoryAddressTag) { + // the value in location is the memory address of the data + const void* ext_data_buf = reinterpret_cast(file_offset); + const size_t ext_data_len = tensor_byte_size; + std::memcpy(unpacked_tensor.data(), ext_data_buf, ext_data_len); + } else { + ORT_RETURN_IF_ERROR(onnxruntime::Env::Default().ReadFileIntoBuffer( + external_file_path.c_str(), + file_offset, + tensor_byte_size, + gsl::make_span(reinterpret_cast(unpacked_tensor.data()), tensor_byte_size))); + } return Status::OK(); } diff --git a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc index 67e86cf60a63d..b095d6e7b3c4b 100644 --- a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc @@ -560,6 +560,7 @@ TEST_F(QnnHTPBackendTests, CompileApi_FromSessionOptions_OutputModelBuffer_Outpu // Tests compiling an OrtModel created using the OrtModelEditor API. 
TEST_F(QnnHTPBackendTests, CompileApi_InputOrtModel_OutputFile) { std::vector>> weights; // Model weights must remain valid through inference + // if we want to avoid a copy. // Create OrtModel with a Gemm. X input is 3x4, Y initializer is 4x8, Z output is 3x8. Ort::Graph graph; @@ -590,7 +591,7 @@ TEST_F(QnnHTPBackendTests, CompileApi_InputOrtModel_OutputFile) { auto mem_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); auto y_tensor = Ort::Value::CreateTensor(mem_info, y_values.data(), y_values.size(), y_dims.data(), y_dims.size()); - graph.AddInitializer("Y", y_tensor, /*data is external*/ false); // TODO: external data does not serialize to proto (error) + graph.AddInitializer("Y", y_tensor, /*data is external, avoid copy*/ true); std::vector opsets{{onnxruntime::kOnnxDomain, 18}, {onnxruntime::kMSDomain, 1}}; Ort::Model model(opsets); From 94596eb6c17f09fd75f8a04f190cfb700961262b Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Tue, 13 May 2025 01:23:03 -0700 Subject: [PATCH 04/14] Clean up --- onnxruntime/core/framework/tensorprotoutils.cc | 3 +-- onnxruntime/core/session/utils.h | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/framework/tensorprotoutils.cc b/onnxruntime/core/framework/tensorprotoutils.cc index 83ad58dc610e9..2efa03bd3adf1 100644 --- a/onnxruntime/core/framework/tensorprotoutils.cc +++ b/onnxruntime/core/framework/tensorprotoutils.cc @@ -182,8 +182,7 @@ Status ReadExternalDataForTensor(const ONNX_NAMESPACE::TensorProto& tensor_proto if (external_file_path == onnxruntime::utils::kTensorProtoMemoryAddressTag) { // the value in location is the memory address of the data const void* ext_data_buf = reinterpret_cast(file_offset); - const size_t ext_data_len = tensor_byte_size; - std::memcpy(unpacked_tensor.data(), ext_data_buf, ext_data_len); + std::memcpy(unpacked_tensor.data(), ext_data_buf, tensor_byte_size); } else { 
ORT_RETURN_IF_ERROR(onnxruntime::Env::Default().ReadFileIntoBuffer( external_file_path.c_str(), diff --git a/onnxruntime/core/session/utils.h b/onnxruntime/core/session/utils.h index a934f19ab1535..9c5f7e34a3d80 100644 --- a/onnxruntime/core/session/utils.h +++ b/onnxruntime/core/session/utils.h @@ -16,7 +16,6 @@ struct OrtSessionOptions; struct OrtStatus; struct OrtPrepackedWeightsContainer; namespace onnxruntime { -class Environment; class InferenceSession; class ModelCompilationOptions; } // namespace onnxruntime @@ -24,6 +23,7 @@ class ModelCompilationOptions; #if !defined(ORT_MINIMAL_BUILD) struct OrtModel; namespace onnxruntime { +class Environment; class EpLibrary; class EpFactoryInternal; struct IExecutionProviderFactory; From 134a702b11cec2d38008eca1209bc81a5d19f283 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Tue, 13 May 2025 18:36:15 -0700 Subject: [PATCH 05/14] Add C API to write compiled model to stream --- .../core/session/onnxruntime_c_api.h | 31 +++++++++ .../core/framework/ep_context_options.cc | 35 ++++++++++ .../core/framework/ep_context_options.h | 68 +++++++++++++++++++ .../core/framework/graph_partitioner.cc | 49 +++++++------ .../core/framework/graph_partitioner.h | 7 +- onnxruntime/core/framework/session_options.cc | 13 +--- onnxruntime/core/framework/session_options.h | 28 ++------ onnxruntime/core/session/compile_api.cc | 23 +++++++ onnxruntime/core/session/compile_api.h | 2 + .../core/session/model_compilation_options.cc | 63 ++++++++++------- .../core/session/model_compilation_options.h | 7 ++ onnxruntime/core/session/utils.cc | 6 +- .../test/framework/session_state_test.cc | 3 +- 13 files changed, 247 insertions(+), 88 deletions(-) create mode 100644 onnxruntime/core/framework/ep_context_options.cc create mode 100644 onnxruntime/core/framework/ep_context_options.h diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index c0cc59739e076..38489ad7618e2 100644 --- 
a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -453,6 +453,24 @@ typedef OrtStatus*(ORT_API_CALL* EpSelectionDelegate)(_In_ const OrtEpDevice** e _Out_ size_t* num_selected, _In_ void* state); +/** \brief Function that writes a buffer to a stream. + * + * \param state Opaque pointer holding the state for the user's stream. + * \param buffer The buffer to write to the stream. + * \param buffer_num_bytes The size of the buffer in bytes. + * \param num_bytes_written Output parameter that should be set to the number of bytes written to + * the stream. ONNX Runtime will continuously call this write function until + * all bytes have been written to the stream. + * + * \return OrtStatus* Write status. Return nullptr on success. + * Use CreateStatus to provide error info. Use ORT_FAIL as the error code. + * ORT will release the OrtStatus* if not null. + */ +typedef OrtStatus*(ORT_API_CALL* WriteToStreamFunc)(_In_ void* state, + _In_ const void* buffer, + _In_ size_t buffer_num_bytes, + _Out_ size_t* num_bytes_written); + /** \brief Algorithm to use for cuDNN Convolution Op */ typedef enum OrtCudnnConvAlgoSearch { @@ -5976,6 +5994,19 @@ struct OrtCompileApi { */ ORT_API2_STATUS(ModelCompilationOptions_SetInputModel, _In_ OrtModelCompilationOptions* model_compile_options, _In_ const OrtModel* input_model); + + /** \brief Sets the WriteToStreamFunc function that ONNX Runtime should call to write the output model to a stream. + * + * \param[in] model_compile_options The OrtModelCompilationOptions instance. + * \param[in] write_stream_func The WriteToStreamFunc function to call when writing out the model. + * \param[in] state Opaque state passed as the first argument to WriteToStreamFunc. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.22. 
+ */ + ORT_API2_STATUS(ModelCompilationOptions_SetOutputModelStream, _In_ OrtModelCompilationOptions* model_compile_options, + _In_ WriteToStreamFunc write_stream_func, _In_ void* state); }; ORT_RUNTIME_CLASS(Ep); diff --git a/onnxruntime/core/framework/ep_context_options.cc b/onnxruntime/core/framework/ep_context_options.cc new file mode 100644 index 0000000000000..c22fda550f34e --- /dev/null +++ b/onnxruntime/core/framework/ep_context_options.cc @@ -0,0 +1,35 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/framework/ep_context_options.h" +#include "core/session/onnxruntime_session_options_config_keys.h" + +namespace onnxruntime { +namespace epctx { +ModelGenOptions::ModelGenOptions(const ConfigOptions& config_options) { + enable = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0") == "1"; + output_model_location = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, ""); + output_external_initializers_file_path = config_options.GetConfigOrDefault( + kOrtSessionOptionsEpContextModelExternalInitializersFileName, ""); + output_external_initializer_size_threshold = 0; + embed_ep_context_in_model = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEmbedMode, "0") == "1"; +} + +bool ModelGenOptions::HasOutputModelLocation() const { + return !std::holds_alternative(output_model_location); +} + +const std::string* ModelGenOptions::TryGetOutputModelPath() const { + return std::get_if(&output_model_location); +} + +const BufferHolder* ModelGenOptions::TryGetOutputModelBuffer() const { + return std::get_if(&output_model_location); +} + +const StreamHolder* ModelGenOptions::TryGetOutputModelStream() const { + return std::get_if(&output_model_location); +} + +} // namespace epctx +} // namespace onnxruntime diff --git a/onnxruntime/core/framework/ep_context_options.h b/onnxruntime/core/framework/ep_context_options.h new file mode 100644 index 
0000000000000..6ab6031a6a1e4 --- /dev/null +++ b/onnxruntime/core/framework/ep_context_options.h @@ -0,0 +1,68 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include "core/framework/allocator.h" +#include "core/framework/config_options.h" + +namespace onnxruntime { +namespace epctx { +struct BufferHolder { + void** buffer_ptr = nullptr; + size_t* buffer_size_ptr = nullptr; + AllocatorPtr buffer_allocator = nullptr; +}; + +struct StreamHolder { + WriteToStreamFunc write_func; + void* state; // Opaque pointer to user's stream state. Passed as first argument to write_func. +}; + +/* +class WriteFuncStreamBuf : public std::streambuf { + public: + WriteFuncStreamBuf(StreamHolder write_func_holder); + ~WriteFuncStreamBuf(); + + protected: + int_type overflow(int_type ch) override; + int sync() override; + + private: + int FlushBuffer(); + + StreamHolder write_func_holder_; + std::array buffer_; +}; +*/ + +struct ModelGenOptions { + ModelGenOptions() = default; + + // Initializes from string key/value pairs in session config options. + explicit ModelGenOptions(const ConfigOptions& config_options); + + bool enable = false; + bool overwrite_existing_output_file = false; + bool error_if_no_compiled_nodes = false; + bool embed_ep_context_in_model = false; + + std::variant // Function to write the output model to a user's stream. 
+ output_model_location; + + std::string output_external_initializers_file_path; + size_t output_external_initializer_size_threshold = 0; + + bool HasOutputModelLocation() const; + const std::string* TryGetOutputModelPath() const; + const BufferHolder* TryGetOutputModelBuffer() const; + const StreamHolder* TryGetOutputModelStream() const; +}; +} // namespace epctx +} // namespace onnxruntime diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index 8ed5eeaa8d44f..280a92092793a 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ -9,6 +9,7 @@ #include "core/common/inlined_containers.h" #include "core/common/string_utils.h" #include "core/framework/compute_capability.h" +#include "core/framework/ep_context_options.h" #include "core/framework/execution_providers.h" #include "core/framework/func_kernel.h" #include "core/framework/kernel_lookup.h" @@ -794,7 +795,7 @@ static Status GetValidatedEpContextPath(const std::filesystem::path& ep_context_ static Status CreateEpContextModel(const ExecutionProviders& execution_providers, const Graph& graph, - const EpContextModelGenerationOptions& ep_context_gen_options, + const epctx::ModelGenOptions& ep_context_gen_options, const logging::Logger& logger) { InlinedVector all_ep_context_nodes; for (const auto& ep : execution_providers) { @@ -824,15 +825,16 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers return std::make_pair(false, static_cast(nullptr)); }; - bool saving_to_buffer = ep_context_gen_options.output_model_buffer_ptr != nullptr && - ep_context_gen_options.output_model_buffer_size_ptr != nullptr && - ep_context_gen_options.output_model_buffer_allocator != nullptr; + const auto* output_buffer_holder = ep_context_gen_options.TryGetOutputModelBuffer(); + const auto* output_stream_holder = ep_context_gen_options.TryGetOutputModelStream(); + const std::string* 
output_model_path_ptr = ep_context_gen_options.TryGetOutputModelPath(); - std::filesystem::path context_cache_path; - if (!saving_to_buffer || !graph.ModelPath().empty()) { - ORT_RETURN_IF_ERROR(GetValidatedEpContextPath(ep_context_gen_options.output_model_file_path, + std::filesystem::path valid_output_model_path; + if (output_model_path_ptr != nullptr || !graph.ModelPath().empty()) { + std::string output_model_path = (output_model_path_ptr != nullptr) ? *output_model_path_ptr : ""; + ORT_RETURN_IF_ERROR(GetValidatedEpContextPath(output_model_path, graph.ModelPath(), - context_cache_path, + valid_output_model_path, ep_context_gen_options.overwrite_existing_output_file)); } @@ -894,25 +896,27 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers ModelSavingOptions model_saving_options{ini_size_threshold}; - if (saving_to_buffer) { + if (output_buffer_holder != nullptr) { ORT_RETURN_IF_ERROR(ep_context_model.MainGraph().Resolve()); // TODO(adrianlizarraga): Investigate if we can make this more memory efficient. // May be able to use allocator to directly allocate the ModelProto to avoid a copy. 
ONNX_NAMESPACE::ModelProto model_proto = ep_context_model.ToGraphProtoWithExternalInitializers(external_ini_path, - context_cache_path, + valid_output_model_path, model_saving_options); size_t buffer_size = model_proto.ByteSizeLong(); ORT_RETURN_IF(buffer_size > static_cast(std::numeric_limits::max()), "Cannot serialize ONNX ModelProto larger than 2GB"); - AllocatorPtr allocator = ep_context_gen_options.output_model_buffer_allocator; + AllocatorPtr allocator = output_buffer_holder->buffer_allocator; IAllocatorUniquePtr buffer = IAllocator::MakeUniquePtr(allocator, buffer_size); model_proto.SerializeToArray(buffer.get(), static_cast(buffer_size)); - *ep_context_gen_options.output_model_buffer_size_ptr = buffer_size; - *ep_context_gen_options.output_model_buffer_ptr = buffer.release(); + *output_buffer_holder->buffer_size_ptr = buffer_size; + *output_buffer_holder->buffer_ptr = buffer.release(); + } else if (output_stream_holder != nullptr) { + // TODO } else { - ORT_RETURN_IF_ERROR(Model::SaveWithExternalInitializers(ep_context_model, context_cache_path, + ORT_RETURN_IF_ERROR(Model::SaveWithExternalInitializers(ep_context_model, valid_output_model_path, external_ini_path, model_saving_options)); } @@ -1164,7 +1168,7 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr, const ConfigOptions& config_options, const logging::Logger& logger, Mode mode, - const EpContextModelGenerationOptions& ep_context_gen_options, + const epctx::ModelGenOptions& ep_context_gen_options, const layout_transformation::DebugGraphFn& debug_graph_fn) const { // It is a greedy partitioning algorithm per provider preferences user provided when calling ONNX RUNTIME right now. // 1. Execution providers' capabilities are checked one by one. 
@@ -1211,12 +1215,15 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr, if (mode == Mode::kNormal || mode == Mode::kAssignOnly) { #if !defined(ORT_MINIMAL_BUILD) - if (ep_context_gen_options.enable && ep_context_gen_options.output_model_buffer_ptr == nullptr) { - // Check before EP compile graphs - std::filesystem::path context_cache_path; - ORT_RETURN_IF_ERROR(GetValidatedEpContextPath(ep_context_gen_options.output_model_file_path, graph.ModelPath(), - context_cache_path, - ep_context_gen_options.overwrite_existing_output_file)); + if (ep_context_gen_options.enable) { + if (const std::string* output_model_path_ptr = ep_context_gen_options.TryGetOutputModelPath(); + output_model_path_ptr != nullptr) { + // Check before EP compile graphs + std::filesystem::path context_cache_path; + ORT_RETURN_IF_ERROR(GetValidatedEpContextPath(*output_model_path_ptr, graph.ModelPath(), + context_cache_path, + ep_context_gen_options.overwrite_existing_output_file)); + } } // We use this only if Resource Aware Partitioning is enabled for any of the EPs diff --git a/onnxruntime/core/framework/graph_partitioner.h b/onnxruntime/core/framework/graph_partitioner.h index 6e36d79701fd7..abe46cea58ab2 100644 --- a/onnxruntime/core/framework/graph_partitioner.h +++ b/onnxruntime/core/framework/graph_partitioner.h @@ -15,7 +15,10 @@ class ExecutionProviders; class KernelRegistryManager; class Model; struct ConfigOptions; -struct EpContextModelGenerationOptions; + +namespace epctx { +struct ModelGenOptions; +} class GraphPartitioner { public: @@ -50,7 +53,7 @@ class GraphPartitioner { const ConfigOptions& config_options, const logging::Logger& logger, Mode mode = Mode::kNormal, - const EpContextModelGenerationOptions& ep_context_gen_options = {}, + const epctx::ModelGenOptions& ep_context_gen_options = {}, const layout_transformation::DebugGraphFn& debug_graph_fn = {}) const; bool IsLoadCancellationFlagSet() const { diff --git 
a/onnxruntime/core/framework/session_options.cc b/onnxruntime/core/framework/session_options.cc index 231eb47603838..63f928d52d788 100644 --- a/onnxruntime/core/framework/session_options.cc +++ b/onnxruntime/core/framework/session_options.cc @@ -99,20 +99,11 @@ void SessionOptions::AddCustomOpLibraryHandle(PathString library_name, void* lib } #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) -EpContextModelGenerationOptions::EpContextModelGenerationOptions(const ConfigOptions& config_options) { - enable = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0") == "1"; - output_model_file_path = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, ""); - output_external_initializers_file_path = config_options.GetConfigOrDefault( - kOrtSessionOptionsEpContextModelExternalInitializersFileName, ""); - output_external_initializer_size_threshold = 0; - embed_ep_context_in_model = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEmbedMode, "0") == "1"; -} - -EpContextModelGenerationOptions SessionOptions::GetEpContextGenerationOptions() const { +epctx::ModelGenOptions SessionOptions::GetEpContextGenerationOptions() const { if (this->has_explicit_ep_context_gen_options) { return this->ep_context_gen_options; } - return EpContextModelGenerationOptions(this->config_options); + return epctx::ModelGenOptions(this->config_options); } } // namespace onnxruntime diff --git a/onnxruntime/core/framework/session_options.h b/onnxruntime/core/framework/session_options.h index 89a43c4f71ee6..e44ac95c2c890 100644 --- a/onnxruntime/core/framework/session_options.h +++ b/onnxruntime/core/framework/session_options.h @@ -11,8 +11,8 @@ #include #include #include "core/common/inlined_containers.h" -#include "core/framework/allocator.h" #include "core/framework/config_options.h" +#include "core/framework/ep_context_options.h" #include "core/framework/ort_value.h" #include "core/session/onnxruntime_c_api.h" #include 
"core/optimizer/graph_transformer_level.h" @@ -70,26 +70,6 @@ struct FreeDimensionOverride { using CheckLoadCancellationFn = std::function; -struct EpContextModelGenerationOptions { - EpContextModelGenerationOptions() = default; - - // Initializes from string key/value pairs in session config options. - explicit EpContextModelGenerationOptions(const ConfigOptions& config_options); - - bool enable = false; - bool overwrite_existing_output_file = false; - bool error_if_no_compiled_nodes = false; - bool embed_ep_context_in_model = false; - - std::string output_model_file_path; - void** output_model_buffer_ptr = nullptr; - size_t* output_model_buffer_size_ptr = nullptr; - AllocatorPtr output_model_buffer_allocator = nullptr; - - std::string output_external_initializers_file_path; - size_t output_external_initializer_size_threshold = 0; -}; - struct EpSelectionPolicy { // flag to detect that a policy was set by the user. // need to preserve current behavior of defaulting to CPU EP if no EPs are explicitly registered @@ -239,12 +219,12 @@ struct SessionOptions { // Options for generating compile EPContext models were previously stored in session_option.configs as // string key/value pairs. To support more advanced options, such as setting input/output buffers, we - // now have to store EPContext options in a struct of type EpContextModelGenerationOptions. + // now have to store EPContext options in a struct of type epctx::ModelGenOptions. // The function GetEpContextGenerationOptions() handles conversion of string key/value pairs to the new // struct type. 
bool has_explicit_ep_context_gen_options = false; - EpContextModelGenerationOptions ep_context_gen_options = {}; - EpContextModelGenerationOptions GetEpContextGenerationOptions() const; + epctx::ModelGenOptions ep_context_gen_options = {}; + epctx::ModelGenOptions GetEpContextGenerationOptions() const; }; inline std::ostream& operator<<(std::ostream& os, const SessionOptions& session_options) { diff --git a/onnxruntime/core/session/compile_api.cc b/onnxruntime/core/session/compile_api.cc index 29caf93a6ebe1..d571edd7647d6 100644 --- a/onnxruntime/core/session/compile_api.cc +++ b/onnxruntime/core/session/compile_api.cc @@ -207,6 +207,28 @@ ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetOutputModelBuffer, API_IMPL_END } +ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetOutputModelStream, + _In_ OrtModelCompilationOptions* ort_model_compile_options, + _In_ WriteToStreamFunc write_stream_func, _In_ void* state) { + API_IMPL_BEGIN +#if !defined(ORT_MINIMAL_BUILD) + auto model_compile_options = reinterpret_cast(ort_model_compile_options); + + if (write_stream_func == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "WriteToStreamFunc function for output model is null"); + } + + model_compile_options->SetOutputModelStream(write_stream_func, state); + return nullptr; +#else + ORT_UNUSED_PARAMETER(ort_model_compile_options); + ORT_UNUSED_PARAMETER(write_stream_func); + ORT_UNUSED_PARAMETER(state); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Compile API is not supported in this build"); +#endif // !defined(ORT_MINIMAL_BUILD) + API_IMPL_END +} + ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetEpContextEmbedMode, _In_ OrtModelCompilationOptions* ort_model_compile_options, bool embed_ep_context_in_model) { @@ -252,6 +274,7 @@ static constexpr OrtCompileApi ort_compile_api = { &OrtCompileAPI::ModelCompilationOptions_SetEpContextEmbedMode, &OrtCompileAPI::CompileModel, 
&OrtCompileAPI::ModelCompilationOptions_SetInputModel, + &OrtCompileAPI::ModelCompilationOptions_SetOutputModelStream, }; // checks that we don't violate the rule that the functions must remain in the slots they were originally assigned diff --git a/onnxruntime/core/session/compile_api.h b/onnxruntime/core/session/compile_api.h index 537d8580b8649..1cc63bc87e3ef 100644 --- a/onnxruntime/core/session/compile_api.h +++ b/onnxruntime/core/session/compile_api.h @@ -30,5 +30,7 @@ ORT_API_STATUS_IMPL(ModelCompilationOptions_SetEpContextEmbedMode, _In_ OrtModel ORT_API_STATUS_IMPL(CompileModel, _In_ const OrtEnv* env, _In_ const OrtModelCompilationOptions* model_options); ORT_API_STATUS_IMPL(ModelCompilationOptions_SetInputModel, _In_ OrtModelCompilationOptions* model_compile_options, _In_ const OrtModel* input_model); +ORT_API_STATUS_IMPL(ModelCompilationOptions_SetOutputModelStream, _In_ OrtModelCompilationOptions* model_compile_options, + _In_ WriteToStreamFunc write_stream_func, _In_ void* state); } // namespace OrtCompileAPI diff --git a/onnxruntime/core/session/model_compilation_options.cc b/onnxruntime/core/session/model_compilation_options.cc index 480946491e1ae..fae4e897ad0ef 100644 --- a/onnxruntime/core/session/model_compilation_options.cc +++ b/onnxruntime/core/session/model_compilation_options.cc @@ -9,6 +9,7 @@ #include #include "core/framework/allocator.h" +#include "core/framework/ep_context_options.h" #include "core/session/onnxruntime_session_options_config_keys.h" #include "core/session/environment.h" @@ -45,13 +46,12 @@ Status ModelCompilationOptions::SetOutputModelPath(const std::string& output_mod ORT_RETURN_IF_ERROR(ResetOutputModelSettings()); ConfigOptions& config_options = session_options_.value.config_options; - EpContextModelGenerationOptions& ep_context_gen_options = session_options_.value.ep_context_gen_options; + epctx::ModelGenOptions& ep_context_gen_options = session_options_.value.ep_context_gen_options; - 
ep_context_gen_options.output_model_file_path = output_model_path; + ep_context_gen_options.output_model_location = output_model_path; - if (ep_context_gen_options.output_model_file_path.size() <= ConfigOptions::kMaxValueLength) { - Status status = config_options.AddConfigEntry(kOrtSessionOptionEpContextFilePath, - ep_context_gen_options.output_model_file_path.c_str()); + if (output_model_path.size() <= ConfigOptions::kMaxValueLength) { + Status status = config_options.AddConfigEntry(kOrtSessionOptionEpContextFilePath, output_model_path.c_str()); ORT_ENFORCE(status.IsOK()); // Should not fail because both key/value strings are below the min string lengths // required by ConfigOptions::AddConfigEntry(). } else { @@ -72,7 +72,7 @@ Status ModelCompilationOptions::SetOutputModelPath(const std::string& output_mod logging::LoggingManager* log_manager = env_.GetLoggingManager(); if (log_manager != nullptr && log_manager->HasDefaultLogger()) { const logging::Logger& logger = log_manager->DefaultLogger(); - LOGS(logger, WARNING) << "Output model path length (" << ep_context_gen_options.output_model_file_path.size() + LOGS(logger, WARNING) << "Output model path length (" << output_model_path.size() << ") exceeds limit of " << ConfigOptions::kMaxKeyLength << " characters." 
<< "ORT will still generated the expected output file, but EPs will see an empty " << "output model path in SessionOption's ConfigOptions."; @@ -93,12 +93,21 @@ Status ModelCompilationOptions::SetOutputModelBuffer(onnxruntime::AllocatorPtr a size_t* output_model_buffer_size_ptr) { ORT_RETURN_IF_ERROR(ResetOutputModelSettings()); - session_options_.value.ep_context_gen_options.output_model_buffer_ptr = output_model_buffer_ptr; - session_options_.value.ep_context_gen_options.output_model_buffer_size_ptr = output_model_buffer_size_ptr; - session_options_.value.ep_context_gen_options.output_model_buffer_allocator = std::move(allocator); + session_options_.value.ep_context_gen_options.output_model_location = epctx::BufferHolder{ + output_model_buffer_ptr, + output_model_buffer_size_ptr, + std::move(allocator), + }; return Status::OK(); } +void ModelCompilationOptions::SetOutputModelStream(WriteToStreamFunc write_stream_func, void* stream_state) { + session_options_.value.ep_context_gen_options.output_model_location = epctx::StreamHolder{ + write_stream_func, + stream_state, + }; +} + Status ModelCompilationOptions::SetEpContextEmbedMode(bool embed_ep_context_in_model) { ORT_RETURN_IF_ERROR(session_options_.value.config_options.AddConfigEntry( kOrtSessionOptionEpContextEmbedMode, embed_ep_context_in_model ? 
"1" : "0")); @@ -125,11 +134,6 @@ const OrtModel* ModelCompilationOptions::TryGetInputOrtModel() const { } Status ModelCompilationOptions::ResetOutputModelSettings() { - EpContextModelGenerationOptions& ep_context_gen_options = session_options_.value.ep_context_gen_options; - ep_context_gen_options.output_model_file_path.clear(); - ep_context_gen_options.output_model_buffer_ptr = nullptr; - ep_context_gen_options.output_model_buffer_size_ptr = nullptr; - ep_context_gen_options.output_model_buffer_allocator = nullptr; return session_options_.value.config_options.AddConfigEntry(kOrtSessionOptionEpContextFilePath, ""); } @@ -161,38 +165,47 @@ Status ModelCompilationOptions::Check() const { } // Check output model settings - const EpContextModelGenerationOptions& ep_context_gen_options = session_options_.value.ep_context_gen_options; + const epctx::ModelGenOptions& ep_context_gen_options = session_options_.value.ep_context_gen_options; - const bool explicit_output_to_file = !ep_context_gen_options.output_model_file_path.empty(); - const bool output_to_buffer = ep_context_gen_options.output_model_buffer_ptr != nullptr; + const bool has_no_output_model_location = std::holds_alternative( + ep_context_gen_options.output_model_location); - if (!explicit_output_to_file && !output_to_buffer && input_model_path_ptr != nullptr) { - // User did not specify an output file or an output buffer. We default to generating an output file + if (has_no_output_model_location && input_model_path_ptr != nullptr) { + // User did not specify an output file, output buffer, or output stream. We default to generating an output file // with a name based on the input file name, so do not return an error. 
return Status::OK(); } - if (!explicit_output_to_file && !output_to_buffer) { + if (has_no_output_model_location) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unable to generate an output model path: require an input model path if the location " - "of the output model (e.g., file or buffer) is not specified."); + "of the output model (e.g., file, buffer, or stream) is not specified."); } - if (explicit_output_to_file && output_to_buffer) { + const epctx::BufferHolder* output_buffer_ptr = ep_context_gen_options.TryGetOutputModelBuffer(); + + if (output_buffer_ptr != nullptr && output_buffer_ptr->buffer_ptr == nullptr) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Output model to compile must be saved either to a file or to a buffer, but not both."); + "Invalid buffer configuration for output model: buffer pointer is null"); } - if (output_to_buffer && ep_context_gen_options.output_model_buffer_size_ptr == nullptr) { + if (output_buffer_ptr != nullptr && output_buffer_ptr->buffer_size_ptr == nullptr) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid buffer configuration for output model: size pointer is null"); } - if (output_to_buffer && ep_context_gen_options.output_model_buffer_allocator == nullptr) { + if (output_buffer_ptr != nullptr && output_buffer_ptr->buffer_allocator == nullptr) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid buffer configuration for output model: allocator is null"); } + const epctx::StreamHolder* output_stream_ptr = ep_context_gen_options.TryGetOutputModelStream(); + + if (output_stream_ptr != nullptr && output_stream_ptr->write_func == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Invalid write-to-stream function for output model: function pointer is null"); + } + return Status::OK(); } } // namespace onnxruntime diff --git a/onnxruntime/core/session/model_compilation_options.h b/onnxruntime/core/session/model_compilation_options.h index 
f3f73ef262665..e2666fe8412f9 100644 --- a/onnxruntime/core/session/model_compilation_options.h +++ b/onnxruntime/core/session/model_compilation_options.h @@ -82,6 +82,13 @@ class ModelCompilationOptions { Status SetOutputModelBuffer(onnxruntime::AllocatorPtr allocator, void** output_model_buffer_ptr, size_t* output_model_buffer_size_ptr); + /// + /// Sets the function to write the output model to a stream. + /// + /// Write function + /// The stream state + void SetOutputModelStream(WriteToStreamFunc write_stream_func, void* stream_state); + /// /// Enables or disables the embedding of EPContext binary data into the `ep_cache_context` attribute of EPContext /// nodes. Defaults to false (dumped to file). diff --git a/onnxruntime/core/session/utils.cc b/onnxruntime/core/session/utils.cc index 46f5497ca2035..c550640e60c63 100644 --- a/onnxruntime/core/session/utils.cc +++ b/onnxruntime/core/session/utils.cc @@ -139,13 +139,11 @@ static OrtStatus* CreateSessionAndLoadModelImpl(_In_ const OrtSessionOptions* op // If ep.context_enable is set, then ep.context_file_path is expected, otherwise ORT don't know where to generate the _ctx.onnx file if (options && model_path == nullptr) { - EpContextModelGenerationOptions ep_ctx_gen_options = options->value.GetEpContextGenerationOptions(); + epctx::ModelGenOptions ep_ctx_gen_options = options->value.GetEpContextGenerationOptions(); // This is checked by the OrtCompileApi's CompileModel() function, but we check again here in case // the user used the older SessionOptions' configuration entries to generate a compiled model. 
- if (ep_ctx_gen_options.enable && - ep_ctx_gen_options.output_model_file_path.empty() && - ep_ctx_gen_options.output_model_buffer_ptr == nullptr) { + if (ep_ctx_gen_options.enable && !ep_ctx_gen_options.HasOutputModelLocation()) { return OrtApis::CreateStatus(ORT_FAIL, "Inference session was configured with EPContext model generation enabled but " "without a valid location (e.g., file or buffer) for the output model. " diff --git a/onnxruntime/test/framework/session_state_test.cc b/onnxruntime/test/framework/session_state_test.cc index 76399743c97f8..3b90cfcce76db 100644 --- a/onnxruntime/test/framework/session_state_test.cc +++ b/onnxruntime/test/framework/session_state_test.cc @@ -11,6 +11,7 @@ #include "core/framework/kernel_registry.h" #include "core/framework/op_kernel.h" #include "core/framework/bfc_arena.h" +#include "core/framework/ep_context_options.h" #include "core/framework/session_state.h" #include "core/graph/graph_utils.h" #include "core/graph/graph_viewer.h" @@ -504,7 +505,7 @@ void LoadWithResourceAwarePartitioning(const ORTCHAR_T* model_path, ASSERT_STATUS_OK( partitioner.Partition(graph, session_state.GetMutableFuncMgr(), transform_layout_fn, sess_options.config_options, default_logger, GraphPartitioner::Mode::kNormal, - EpContextModelGenerationOptions{}, + epctx::ModelGenOptions{}, debug_graph_fn)); verifier_fn(graph); From 7f719a9b984c9b6789b4a1ca6958b536c167ea25 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Wed, 14 May 2025 02:52:41 -0700 Subject: [PATCH 06/14] Serialize proto to stream. 
Need to add a test. --- .../core/session/onnxruntime_c_api.h | 29 +++--- .../core/framework/ep_context_options.cc | 94 ++++++++++++++++++- .../core/framework/ep_context_options.h | 66 +++++++------ .../core/framework/graph_partitioner.cc | 23 ++++- onnxruntime/core/session/compile_api.cc | 8 +- onnxruntime/core/session/compile_api.h | 4 +- .../core/session/model_compilation_options.cc | 6 +- .../core/session/model_compilation_options.h | 4 +- 8 files changed, 174 insertions(+), 60 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 38489ad7618e2..16f804a2a3d34 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -455,21 +455,19 @@ typedef OrtStatus*(ORT_API_CALL* EpSelectionDelegate)(_In_ const OrtEpDevice** e /** \brief Function that writes a buffer to a stream. * - * \param state Opaque pointer holding the state for the user's stream. + * \param stream_state Opaque pointer holding the state for the user's stream. * \param buffer The buffer to write to the stream. * \param buffer_num_bytes The size of the buffer in bytes. - * \param num_bytes_written Output parameter that should be set to the number of bytes written to - * the stream. ONNX Runtime will continuously call this write function until - * all bytes have been written to the stream. + * \param num_bytes_written Output parameter that should be set to the number of bytes written to the stream. * * \return OrtStatus* Write status. Return nullptr on success. * Use CreateStatus to provide error info. Use ORT_FAIL as the error code. * ORT will release the OrtStatus* if not null. 
*/ -typedef OrtStatus*(ORT_API_CALL* WriteToStreamFunc)(_In_ void* state, - _In_ const void* buffer, - _In_ size_t buffer_num_bytes, - _Out_ size_t* num_bytes_written); +typedef OrtStatus*(ORT_API_CALL* OrtOutStreamWriteFunc)(_In_ void* stream_state, + _In_ const void* buffer, + _In_ size_t buffer_num_bytes, + _Out_ size_t* num_bytes_written); /** \brief Algorithm to use for cuDNN Convolution Op */ @@ -5995,18 +5993,21 @@ struct OrtCompileApi { ORT_API2_STATUS(ModelCompilationOptions_SetInputModel, _In_ OrtModelCompilationOptions* model_compile_options, _In_ const OrtModel* input_model); - /** \brief Sets the WriteToStreamFunc function that ONNX Runtime should call to write the output model to a stream. + /** \brief Sets an output stream used to write out the output model's serialized ONNX bytes. + * + * The write function is called called repeatedly until then entire output model has been written out. * * \param[in] model_compile_options The OrtModelCompilationOptions instance. - * \param[in] write_stream_func The WriteToStreamFunc function to call when writing out the model. - * \param[in] state Opaque state passed as the first argument to WriteToStreamFunc. + * \param[in] write_stream_func The OrtOutStreamWriteFunc function to call when writing out the model. + * \param[in] state Opaque stream state passed as the first argument to WriteToStreamFunc. * * \snippet{doc} snippets.dox OrtStatus Return Value * - * \since Version 1.22. + * \since Version 1.23. 
*/ - ORT_API2_STATUS(ModelCompilationOptions_SetOutputModelStream, _In_ OrtModelCompilationOptions* model_compile_options, - _In_ WriteToStreamFunc write_stream_func, _In_ void* state); + ORT_API2_STATUS(ModelCompilationOptions_SetOutputModelOutStream, + _In_ OrtModelCompilationOptions* model_compile_options, + _In_ OrtOutStreamWriteFunc write_stream_func, _In_ void* stream_state); }; ORT_RUNTIME_CLASS(Ep); diff --git a/onnxruntime/core/framework/ep_context_options.cc b/onnxruntime/core/framework/ep_context_options.cc index c22fda550f34e..70256026c6328 100644 --- a/onnxruntime/core/framework/ep_context_options.cc +++ b/onnxruntime/core/framework/ep_context_options.cc @@ -1,14 +1,26 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include +#include "core/common/common.h" #include "core/framework/ep_context_options.h" +#include "core/framework/error_code_helper.h" #include "core/session/onnxruntime_session_options_config_keys.h" namespace onnxruntime { namespace epctx { +// class ModelGenOptions + ModelGenOptions::ModelGenOptions(const ConfigOptions& config_options) { enable = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0") == "1"; - output_model_location = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, ""); + + std::string output_model_path = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, ""); + if (!output_model_path.empty()) { + output_model_location = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, ""); + } else { + output_model_location = std::monostate{}; + } + output_external_initializers_file_path = config_options.GetConfigOrDefault( kOrtSessionOptionsEpContextModelExternalInitializersFileName, ""); output_external_initializer_size_threshold = 0; @@ -27,8 +39,84 @@ const BufferHolder* ModelGenOptions::TryGetOutputModelBuffer() const { return std::get_if(&output_model_location); } -const StreamHolder* 
ModelGenOptions::TryGetOutputModelStream() const { - return std::get_if(&output_model_location); +const OutStreamHolder* ModelGenOptions::TryGetOutputModelOutStream() const { + return std::get_if(&output_model_location); +} + +// class OutStreamBuf + +OutStreamBuf::OutStreamBuf(OutStreamHolder out_stream_holder) : out_stream_holder_(out_stream_holder) { + setp(buffer_.data(), buffer_.data() + buffer_.size() - 1); // Leave room for overflow character +} + +OutStreamBuf::~OutStreamBuf() { + sync(); +} + +std::streambuf::int_type OutStreamBuf::overflow(std::streambuf::int_type ch) { + if (ch != traits_type::eof()) { + *pptr() = static_cast(ch); + pbump(1); + } + + if (FlushBuffer() == -1) { + return traits_type::eof(); + } + + return ch; +} + +int OutStreamBuf::sync() { + return FlushBuffer(); +} + +int OutStreamBuf::FlushBuffer() { + std::ptrdiff_t num_bytes = pptr() - pbase(); + if (num_bytes == 0) { + return 0; + } + + // Can only call pbump() with an int, so can only write at most 2^31 - 1. 
+ if (num_bytes > std::numeric_limits::max()) { + num_bytes = std::numeric_limits::max(); + } + + std::ptrdiff_t bytes_remaining = num_bytes; + char* ptr = pbase(); + + while (bytes_remaining > 0) { + size_t bytes_written = 0; + Status status = Status::OK(); + + ORT_TRY { + status = ToStatus(out_stream_holder_.write_func(out_stream_holder_.stream_state, + ptr, bytes_remaining, &bytes_written)); + } + ORT_CATCH(const std::exception& e) { + ORT_HANDLE_EXCEPTION([&]() { + status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "Caught exception while calling user's OrtOutStreamWriteFunc callback: ", e.what()); + }); + } + + if (!status.IsOK()) { + last_status_ = std::move(status); + return -1; + } + + if (bytes_written > static_cast(bytes_remaining)) { + last_status_ = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "OrtOutStreamWriteFunc wrote more bytes (", bytes_written, + ") than requested (", bytes_remaining, ")."); + return -1; + } + + bytes_remaining -= static_cast(bytes_written); + ptr += bytes_written; + } + + assert(ptr == pptr()); + pbump(-static_cast(num_bytes)); // Reset internal pointer to point to the beginning of the buffer_ + return 0; } } // namespace epctx diff --git a/onnxruntime/core/framework/ep_context_options.h b/onnxruntime/core/framework/ep_context_options.h index 6ab6031a6a1e4..1cd9ef520b0cb 100644 --- a/onnxruntime/core/framework/ep_context_options.h +++ b/onnxruntime/core/framework/ep_context_options.h @@ -3,6 +3,8 @@ #pragma once +#include +#include #include #include #include "core/framework/allocator.h" @@ -16,29 +18,11 @@ struct BufferHolder { AllocatorPtr buffer_allocator = nullptr; }; -struct StreamHolder { - WriteToStreamFunc write_func; - void* state; // Opaque pointer to user's stream state. Passed as first argument to write_func. +struct OutStreamHolder { + OrtOutStreamWriteFunc write_func = nullptr; + void* stream_state = nullptr; // Opaque pointer to user's stream state. Passed as first argument to write_func. 
}; -/* -class WriteFuncStreamBuf : public std::streambuf { - public: - WriteFuncStreamBuf(StreamHolder write_func_holder); - ~WriteFuncStreamBuf(); - - protected: - int_type overflow(int_type ch) override; - int sync() override; - - private: - int FlushBuffer(); - - StreamHolder write_func_holder_; - std::array buffer_; -}; -*/ - struct ModelGenOptions { ModelGenOptions() = default; @@ -50,11 +34,11 @@ struct ModelGenOptions { bool error_if_no_compiled_nodes = false; bool embed_ep_context_in_model = false; - std::variant // Function to write the output model to a user's stream. - output_model_location; + std::variant // Function to write the output model to a user's stream. + output_model_location{}; std::string output_external_initializers_file_path; size_t output_external_initializer_size_threshold = 0; @@ -62,7 +46,35 @@ struct ModelGenOptions { bool HasOutputModelLocation() const; const std::string* TryGetOutputModelPath() const; const BufferHolder* TryGetOutputModelBuffer() const; - const StreamHolder* TryGetOutputModelStream() const; + const OutStreamHolder* TryGetOutputModelOutStream() const; +}; + +// Class that wraps the user's OrtOutStreamWriteFunc function to enable use with +// C++'s std::ostream. 
+// Example: +// OutStreamHolder stream_holder{write_func, stream_state}; +// std::unique_ptr out_stream_buf = std::make_unique(stream_holder); +// std::ostream out_stream(out_stream_buf.get()); +class OutStreamBuf : public std::streambuf { + public: + OutStreamBuf(OutStreamHolder out_stream_holder); + ~OutStreamBuf(); + + Status GetStatus() const { + return last_status_; + } + + protected: + int_type overflow(int_type ch) override; + int sync() override; + + private: + int FlushBuffer(); + + OutStreamHolder out_stream_holder_{}; + std::array buffer_{}; + Status last_status_{}; }; + } // namespace epctx } // namespace onnxruntime diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index 280a92092793a..7a53563babddf 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ -825,8 +825,8 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers return std::make_pair(false, static_cast(nullptr)); }; - const auto* output_buffer_holder = ep_context_gen_options.TryGetOutputModelBuffer(); - const auto* output_stream_holder = ep_context_gen_options.TryGetOutputModelStream(); + const epctx::BufferHolder* output_buffer_holder = ep_context_gen_options.TryGetOutputModelBuffer(); + const epctx::OutStreamHolder* output_stream_holder = ep_context_gen_options.TryGetOutputModelOutStream(); const std::string* output_model_path_ptr = ep_context_gen_options.TryGetOutputModelPath(); std::filesystem::path valid_output_model_path; @@ -897,9 +897,8 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers ModelSavingOptions model_saving_options{ini_size_threshold}; if (output_buffer_holder != nullptr) { + // Write output model into a buffer ORT allocates for the user. ORT_RETURN_IF_ERROR(ep_context_model.MainGraph().Resolve()); - // TODO(adrianlizarraga): Investigate if we can make this more memory efficient. 
- // May be able to use allocator to directly allocate the ModelProto to avoid a copy. ONNX_NAMESPACE::ModelProto model_proto = ep_context_model.ToGraphProtoWithExternalInitializers(external_ini_path, valid_output_model_path, model_saving_options); @@ -914,8 +913,22 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers *output_buffer_holder->buffer_size_ptr = buffer_size; *output_buffer_holder->buffer_ptr = buffer.release(); } else if (output_stream_holder != nullptr) { - // TODO + // Write output model to user's output stream. + ORT_RETURN_IF_ERROR(ep_context_model.MainGraph().Resolve()); + ONNX_NAMESPACE::ModelProto model_proto = ep_context_model.ToGraphProtoWithExternalInitializers(external_ini_path, + valid_output_model_path, + model_saving_options); + size_t buffer_size = model_proto.ByteSizeLong(); + ORT_RETURN_IF(buffer_size > static_cast(std::numeric_limits::max()), + "Cannot serialize ONNX ModelProto larger than 2GB"); + + auto out_stream_buf = std::make_unique(*output_stream_holder); + std::ostream out_stream(out_stream_buf.get()); + + model_proto.SerializeToOstream(&out_stream); + ORT_RETURN_IF_ERROR(out_stream_buf->GetStatus()); } else { + // Write output model to file. 
ORT_RETURN_IF_ERROR(Model::SaveWithExternalInitializers(ep_context_model, valid_output_model_path, external_ini_path, model_saving_options)); } diff --git a/onnxruntime/core/session/compile_api.cc b/onnxruntime/core/session/compile_api.cc index d571edd7647d6..fef3ffd45796a 100644 --- a/onnxruntime/core/session/compile_api.cc +++ b/onnxruntime/core/session/compile_api.cc @@ -207,9 +207,9 @@ ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetOutputModelBuffer, API_IMPL_END } -ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetOutputModelStream, +ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetOutputModelOutStream, _In_ OrtModelCompilationOptions* ort_model_compile_options, - _In_ WriteToStreamFunc write_stream_func, _In_ void* state) { + _In_ OrtOutStreamWriteFunc write_stream_func, _In_ void* state) { API_IMPL_BEGIN #if !defined(ORT_MINIMAL_BUILD) auto model_compile_options = reinterpret_cast(ort_model_compile_options); @@ -218,7 +218,7 @@ ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetOutputModelStream, return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "WriteToStreamFunc function for output model is null"); } - model_compile_options->SetOutputModelStream(write_stream_func, state); + model_compile_options->SetOutputModelOutStream(write_stream_func, state); return nullptr; #else ORT_UNUSED_PARAMETER(ort_model_compile_options); @@ -274,7 +274,7 @@ static constexpr OrtCompileApi ort_compile_api = { &OrtCompileAPI::ModelCompilationOptions_SetEpContextEmbedMode, &OrtCompileAPI::CompileModel, &OrtCompileAPI::ModelCompilationOptions_SetInputModel, - &OrtCompileAPI::ModelCompilationOptions_SetOutputModelStream, + &OrtCompileAPI::ModelCompilationOptions_SetOutputModelOutStream, }; // checks that we don't violate the rule that the functions must remain in the slots they were originally assigned diff --git a/onnxruntime/core/session/compile_api.h b/onnxruntime/core/session/compile_api.h index 1cc63bc87e3ef..1be772e800ac2 
100644 --- a/onnxruntime/core/session/compile_api.h +++ b/onnxruntime/core/session/compile_api.h @@ -30,7 +30,7 @@ ORT_API_STATUS_IMPL(ModelCompilationOptions_SetEpContextEmbedMode, _In_ OrtModel ORT_API_STATUS_IMPL(CompileModel, _In_ const OrtEnv* env, _In_ const OrtModelCompilationOptions* model_options); ORT_API_STATUS_IMPL(ModelCompilationOptions_SetInputModel, _In_ OrtModelCompilationOptions* model_compile_options, _In_ const OrtModel* input_model); -ORT_API_STATUS_IMPL(ModelCompilationOptions_SetOutputModelStream, _In_ OrtModelCompilationOptions* model_compile_options, - _In_ WriteToStreamFunc write_stream_func, _In_ void* state); +ORT_API_STATUS_IMPL(ModelCompilationOptions_SetOutputModelOutStream, _In_ OrtModelCompilationOptions* model_compile_options, + _In_ OrtOutStreamWriteFunc write_stream_func, _In_ void* state); } // namespace OrtCompileAPI diff --git a/onnxruntime/core/session/model_compilation_options.cc b/onnxruntime/core/session/model_compilation_options.cc index fae4e897ad0ef..eed87a69e2071 100644 --- a/onnxruntime/core/session/model_compilation_options.cc +++ b/onnxruntime/core/session/model_compilation_options.cc @@ -101,8 +101,8 @@ Status ModelCompilationOptions::SetOutputModelBuffer(onnxruntime::AllocatorPtr a return Status::OK(); } -void ModelCompilationOptions::SetOutputModelStream(WriteToStreamFunc write_stream_func, void* stream_state) { - session_options_.value.ep_context_gen_options.output_model_location = epctx::StreamHolder{ +void ModelCompilationOptions::SetOutputModelOutStream(OrtOutStreamWriteFunc write_stream_func, void* stream_state) { + session_options_.value.ep_context_gen_options.output_model_location = epctx::OutStreamHolder{ write_stream_func, stream_state, }; @@ -199,7 +199,7 @@ Status ModelCompilationOptions::Check() const { "Invalid buffer configuration for output model: allocator is null"); } - const epctx::StreamHolder* output_stream_ptr = ep_context_gen_options.TryGetOutputModelStream(); + const epctx::OutStreamHolder* 
output_stream_ptr = ep_context_gen_options.TryGetOutputModelOutStream(); if (output_stream_ptr != nullptr && output_stream_ptr->write_func == nullptr) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, diff --git a/onnxruntime/core/session/model_compilation_options.h b/onnxruntime/core/session/model_compilation_options.h index e2666fe8412f9..a220878b37cab 100644 --- a/onnxruntime/core/session/model_compilation_options.h +++ b/onnxruntime/core/session/model_compilation_options.h @@ -83,11 +83,11 @@ class ModelCompilationOptions { size_t* output_model_buffer_size_ptr); /// - /// Sets the function to write the output model to a stream. + /// Sets an output stream (write function + state) used to write out the compiled model. /// /// Write function /// The stream state - void SetOutputModelStream(WriteToStreamFunc write_stream_func, void* stream_state); + void SetOutputModelOutStream(OrtOutStreamWriteFunc write_stream_func, void* stream_state); /// /// Enables or disables the embedding of EPContext binary data into the `ep_cache_context` attribute of EPContext From 72eeece674f4fbdcd51c0ebd5cd53e20433b4985 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Wed, 14 May 2025 10:40:03 -0700 Subject: [PATCH 07/14] Add unit test for compiling to user's write stream --- .../core/session/onnxruntime_cxx_api.h | 3 +- .../core/session/onnxruntime_cxx_inline.h | 8 ++ .../core/framework/ep_context_options.h | 9 ++ .../test/providers/qnn/qnn_ep_context_test.cc | 115 +++++++++++++++++- 4 files changed, 130 insertions(+), 5 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index f21c092efb220..cebf1404b631a 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -2864,7 +2864,8 @@ struct ModelCompilationOptions : detail::Base { ModelCompilationOptions& SetOutputModelExternalInitializersFile(const ORTCHAR_T* 
file_path, size_t initializer_size_threshold); ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelExternalInitializersFile ModelCompilationOptions& SetOutputModelBuffer(OrtAllocator* allocator, void** output_model_buffer_ptr, - size_t* output_model_buffer_size_ptr); ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelBuffer + size_t* output_model_buffer_size_ptr); ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelBuffer + ModelCompilationOptions& SetOutputModelOutStream(OrtOutStreamWriteFunc write_stream_func, void* stream_state); ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelOutStream }; /** \brief Compiles an input model to generate a model with EPContext nodes that execute EP-specific kernels. Wraps OrtApi::CompileModels. diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 66bfe6188798b..883e29bfd9a37 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -830,6 +830,14 @@ inline ModelCompilationOptions& ModelCompilationOptions::SetOutputModelBuffer( return *this; } +inline ModelCompilationOptions& ModelCompilationOptions::SetOutputModelOutStream( + OrtOutStreamWriteFunc write_stream_func, void* stream_state) { + Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetOutputModelOutStream(this->p_, + write_stream_func, + stream_state)); + return *this; +} + inline ModelCompilationOptions& ModelCompilationOptions::SetEpContextEmbedMode( bool embed_ep_context_in_model) { Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetEpContextEmbedMode( diff --git a/onnxruntime/core/framework/ep_context_options.h b/onnxruntime/core/framework/ep_context_options.h index 1cd9ef520b0cb..bb59aec4586ce 100644 --- a/onnxruntime/core/framework/ep_context_options.h +++ b/onnxruntime/core/framework/ep_context_options.h @@ -12,17 +12,26 @@ namespace onnxruntime { namespace epctx 
{ +/// +/// Holds the buffer that will store the output model and the allocator used to allocate the memory. +/// struct BufferHolder { void** buffer_ptr = nullptr; size_t* buffer_size_ptr = nullptr; AllocatorPtr buffer_allocator = nullptr; }; +/// +/// Holds the opaque stream state and the write function that ORT calls to write out the output model. +/// struct OutStreamHolder { OrtOutStreamWriteFunc write_func = nullptr; void* stream_state = nullptr; // Opaque pointer to user's stream state. Passed as first argument to write_func. }; +/// +/// Stores EPContext model generation options. Used in SessionOptions. +/// struct ModelGenOptions { ModelGenOptions() = default; diff --git a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc index b095d6e7b3c4b..27e635ef29f7a 100644 --- a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include +#include #include #include "core/session/onnxruntime_cxx_api.h" @@ -557,8 +558,45 @@ TEST_F(QnnHTPBackendTests, CompileApi_FromSessionOptions_OutputModelBuffer_Outpu allocator.Free(output_model_buffer); } -// Tests compiling an OrtModel created using the OrtModelEditor API. -TEST_F(QnnHTPBackendTests, CompileApi_InputOrtModel_OutputFile) { +// Implementation of OrtOutStreamWriteFunc that writes out model to a file. +static OrtStatus* ORT_API_CALL TestWriteToStream(void* stream_state, const void* buffer, size_t buffer_num_bytes, + size_t* num_bytes_written) { + std::ofstream* outfile = reinterpret_cast(stream_state); + + // Write out file in chunks of at most 2048 bytes to test multiple calls to this function. 
+ size_t write_amount = std::min(static_cast(2048), buffer_num_bytes); + outfile->write(reinterpret_cast(buffer), write_amount); + *num_bytes_written = write_amount; + + return nullptr; +} + +// Implementation of OrtOutStreamWriteFunc that writes too much, resulting in an error. +static OrtStatus* ORT_API_CALL WriteTooMuchToStream(void* stream_state, const void* buffer, size_t buffer_num_bytes, + size_t* num_bytes_written) { + ORT_UNUSED_PARAMETER(stream_state); + ORT_UNUSED_PARAMETER(buffer); + + // Incorrectly say we wrote more than requested. ORT should return an error on call to CompileModel(). + *num_bytes_written = buffer_num_bytes + 1; + return nullptr; +} + +// Implementation of OrtOutStreamWriteFunc that directly returns an OrtStatus indicating an error. +static OrtStatus* ORT_API_CALL ReturnStatusFromStream(void* stream_state, const void* buffer, size_t buffer_num_bytes, + size_t* num_bytes_written) { + ORT_UNUSED_PARAMETER(stream_state); + ORT_UNUSED_PARAMETER(buffer); + ORT_UNUSED_PARAMETER(buffer_num_bytes); + + *num_bytes_written = 0; + return Ort::GetApi().CreateStatus(ORT_FAIL, "Error from OrtOutStreamWriteFunc callback"); +} + +// Test using the CompileModel() API with settings: +// - input OrtModel created via OrtModelEditor API +// - write output model to custom stream +TEST_F(QnnHTPBackendTests, CompileApi_InputOrtModel_OutputToStream) { std::vector>> weights; // Model weights must remain valid through inference // if we want to avoid a copy. @@ -607,18 +645,87 @@ TEST_F(QnnHTPBackendTests, CompileApi_InputOrtModel_OutputFile) { const ORTCHAR_T* output_model_file = ORT_TSTR("compileapi_ortmodel_ctx.onnx"); std::filesystem::remove(output_model_file); + // Open an output file. Test will incrementally write the output model to file + // via calls to our OrtOutStreamWriteFunc callback. 
+ ASSERT_FALSE(std::filesystem::exists(output_model_file)); + std::ofstream outfile(output_model_file, std::ios::binary); + // Create model compilation options from the session options. Ort::ModelCompilationOptions compile_options(*ort_env, so); compile_options.SetInputModel(model.GetConst()); - compile_options.SetOutputModelPath(output_model_file); + compile_options.SetOutputModelOutStream(TestWriteToStream, reinterpret_cast(&outfile)); // Set output stream compile_options.SetEpContextEmbedMode(true); // Compile the model. Ort::Status status = Ort::CompileModel(*ort_env, compile_options); ASSERT_TRUE(status.IsOK()) << status.GetErrorMessage(); + outfile.flush(); + outfile.close(); - // Make sure the compiled model was generated and has the expected number of EPContext nodes. + // Make sure the compiled model was generated (via our write stream function) + // and has the expected number of EPContext nodes. ASSERT_TRUE(std::filesystem::exists(output_model_file)); + CheckEpContextNodeCounts(output_model_file, 1, 0); +} + +// Tests using an OrtOutStreamFunc function that writes too much. +TEST_F(QnnHTPBackendTests, CompileApi_OutputStream_WriteTooMuch) { + const ORTCHAR_T* input_model_file = ORT_TSTR("./compileapi_outputstream_writetoomuch.onnx"); + std::filesystem::remove(input_model_file); + + // Create a test model and save it to a file. + TestModel test_model; + CreateTestModel(BuildGraphWithQAndNonQ(false), 21, logging::Severity::kERROR, test_model); + ASSERT_STATUS_OK(test_model.Save(input_model_file)); + + // Initialize session options with QNN EP + Ort::SessionOptions so; + ProviderOptions provider_options; + provider_options["backend_type"] = "htp"; + provider_options["offload_graph_io_quantization"] = "0"; + so.AppendExecutionProvider("QNN", provider_options); + + // Create model compilation options from the session options. 
+ Ort::ModelCompilationOptions compile_options(*ort_env, so); + compile_options.SetInputModelPath(input_model_file); + compile_options.SetOutputModelOutStream(WriteTooMuchToStream, nullptr); // Set output stream that writes too much + compile_options.SetEpContextEmbedMode(true); + + // Compile the model. Expect an error status because our stream wrote too much. + Ort::Status status = Ort::CompileModel(*ort_env, compile_options); + ASSERT_FALSE(status.IsOK()); + EXPECT_EQ(status.GetErrorCode(), ORT_FAIL); + EXPECT_TRUE(status.GetErrorMessage().find("OrtOutStreamWriteFunc wrote more bytes") != std::string::npos); +} + +// Tests using an OrtOutStreamFunc function that returns an error. +TEST_F(QnnHTPBackendTests, CompileApi_OutputStream_ReturnStatus) { + const ORTCHAR_T* input_model_file = ORT_TSTR("./compileapi_outputstream_returnstatus.onnx"); + std::filesystem::remove(input_model_file); + + // Create a test model and save it to a file. + TestModel test_model; + CreateTestModel(BuildGraphWithQAndNonQ(false), 21, logging::Severity::kERROR, test_model); + ASSERT_STATUS_OK(test_model.Save(input_model_file)); + + // Initialize session options with QNN EP + Ort::SessionOptions so; + ProviderOptions provider_options; + provider_options["backend_type"] = "htp"; + provider_options["offload_graph_io_quantization"] = "0"; + so.AppendExecutionProvider("QNN", provider_options); + + // Create model compilation options from the session options. + Ort::ModelCompilationOptions compile_options(*ort_env, so); + compile_options.SetInputModelPath(input_model_file); + compile_options.SetOutputModelOutStream(ReturnStatusFromStream, nullptr); // Set output stream that returns error + compile_options.SetEpContextEmbedMode(true); + + // Compile the model. Expect a specific error status. 
+ Ort::Status status = Ort::CompileModel(*ort_env, compile_options); + ASSERT_FALSE(status.IsOK()); + EXPECT_EQ(status.GetErrorCode(), ORT_FAIL); + EXPECT_EQ(status.GetErrorMessage(), "Error from OrtOutStreamWriteFunc callback"); } // Test that models with 1 non-quantized FusedMatMul node and 1 quantized Add node can still generate the context binary From 2043a36cab4e2e6dbc23381b410c093bb6108080 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Wed, 14 May 2025 11:29:03 -0700 Subject: [PATCH 08/14] Clean up --- .../core/session/onnxruntime_c_api.h | 26 ++++++++++++------- .../core/framework/ep_context_options.cc | 14 ++++------ .../core/framework/ep_context_options.h | 2 -- onnxruntime/core/session/compile_api.cc | 6 ++--- onnxruntime/core/session/compile_api.h | 2 +- .../test/providers/qnn/qnn_ep_context_test.cc | 4 +-- 6 files changed, 28 insertions(+), 26 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 16f804a2a3d34..f57b535a4230e 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -5880,10 +5880,10 @@ struct OrtCompileApi { /** \brief Sets the file path for the output ONNX model generated by CompileModel. * - * The output model's location (e.g., file path or memory buffer) can be set with either - * ModelCompilationOptions_SetOutputModelPath or ModelCompilationOptions_SetOutputModelBuffer. + * The output model's destination (e.g., file path, memory buffer, or stream) can be set with any of the functions + * that begin with ModelCompilationOptions_SetOutputModel____. * - * If the output model's location is not set, ONNX Runtime will generate an output file with a path based on + * If the output model's destination is not set, ONNX Runtime will generate an output file with a path based on * the input model's file path. 
Examples: * /Path/my_model.onnx -> /Path/my_model_ctx.onnx * /Path/my_model -> /Path/my_model_ctx.onnx @@ -5922,10 +5922,10 @@ * * The caller passes an OrtAllocator that ONNX Runtime uses to allocate memory for the buffer. * - * The output model's location (e.g., file path or memory buffer) can be set with either - * ModelCompilationOptions_SetOutputModelPath or ModelCompilationOptions_SetOutputModelBuffer. + * The output model's destination (e.g., file path, memory buffer, or stream) can be set with any of the functions + * that begin with ModelCompilationOptions_SetOutputModel____. * - * If the output model's location is not set, ONNX Runtime will generate an output file with a path based on + * If the output model's destination is not set, ONNX Runtime will generate an output file with a path based on * the input model's file path. Examples: * /Path/my_model.onnx -> /Path/my_model_ctx.onnx * /Path/my_model -> /Path/my_model_ctx.onnx @@ -5984,7 +5984,7 @@ /** \brief Sets the input OrtModel instance to compile. * * \param[in] model_compile_options The OrtModelCompilationOptions instance. - * \param[in] input_model The OrtModel instance of the model to compile. + * \param[in] input_model The OrtModel instance to compile. * * \snippet{doc} snippets.dox OrtStatus Return Value * @@ -5995,11 +5995,19 @@ /** \brief Sets an output stream used to write out the output model's serialized ONNX bytes. * - * The write function is called called repeatedly until then entire output model has been written out. + * The write function is called repeatedly until the entire output model has been written out. + * + * The output model's destination (e.g., file path, memory buffer, or stream) can be set with any of the functions + * that begin with ModelCompilationOptions_SetOutputModel____. 
+ * + * If the output model's destination is not set, ONNX Runtime will generate an output file with a path based on + * the input model's file path. Examples: + * /Path/my_model.onnx -> /Path/my_model_ctx.onnx + * /Path/my_model -> /Path/my_model_ctx.onnx * * \param[in] model_compile_options The OrtModelCompilationOptions instance. * \param[in] write_stream_func The OrtOutStreamWriteFunc function to call when writing out the model. - * \param[in] state Opaque stream state passed as the first argument to WriteToStreamFunc. + * \param[in] state Opaque stream state passed as the first argument to OrtOutStreamWriteFunc. Can be null. * * \snippet{doc} snippets.dox OrtStatus Return Value * diff --git a/onnxruntime/core/framework/ep_context_options.cc b/onnxruntime/core/framework/ep_context_options.cc index 70256026c6328..df2f3c6a42456 100644 --- a/onnxruntime/core/framework/ep_context_options.cc +++ b/onnxruntime/core/framework/ep_context_options.cc @@ -46,7 +46,7 @@ const OutStreamHolder* ModelGenOptions::TryGetOutputModelOutStream() const { // class OutStreamBuf OutStreamBuf::OutStreamBuf(OutStreamHolder out_stream_holder) : out_stream_holder_(out_stream_holder) { - setp(buffer_.data(), buffer_.data() + buffer_.size() - 1); // Leave room for overflow character + setp(buffer_.data(), buffer_.data() + buffer_.size()); } OutStreamBuf::~OutStreamBuf() { @@ -54,23 +54,19 @@ OutStreamBuf::~OutStreamBuf() { } std::streambuf::int_type OutStreamBuf::overflow(std::streambuf::int_type ch) { + if (sync() == -1) { + return traits_type::eof(); + } + if (ch != traits_type::eof()) { *pptr() = static_cast(ch); pbump(1); } - if (FlushBuffer() == -1) { - return traits_type::eof(); - } - return ch; } int OutStreamBuf::sync() { - return FlushBuffer(); -} - -int OutStreamBuf::FlushBuffer() { std::ptrdiff_t num_bytes = pptr() - pbase(); if (num_bytes == 0) { return 0; diff --git a/onnxruntime/core/framework/ep_context_options.h b/onnxruntime/core/framework/ep_context_options.h index 
bb59aec4586ce..f85409dbb2fb1 100644 --- a/onnxruntime/core/framework/ep_context_options.h +++ b/onnxruntime/core/framework/ep_context_options.h @@ -78,8 +78,6 @@ class OutStreamBuf : public std::streambuf { int sync() override; private: - int FlushBuffer(); - OutStreamHolder out_stream_holder_{}; std::array buffer_{}; Status last_status_{}; diff --git a/onnxruntime/core/session/compile_api.cc b/onnxruntime/core/session/compile_api.cc index fef3ffd45796a..9166ce4b11f41 100644 --- a/onnxruntime/core/session/compile_api.cc +++ b/onnxruntime/core/session/compile_api.cc @@ -209,16 +209,16 @@ ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetOutputModelBuffer, ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetOutputModelOutStream, _In_ OrtModelCompilationOptions* ort_model_compile_options, - _In_ OrtOutStreamWriteFunc write_stream_func, _In_ void* state) { + _In_ OrtOutStreamWriteFunc write_stream_func, _In_ void* stream_state) { API_IMPL_BEGIN #if !defined(ORT_MINIMAL_BUILD) auto model_compile_options = reinterpret_cast(ort_model_compile_options); if (write_stream_func == nullptr) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "WriteToStreamFunc function for output model is null"); + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "OrtOutStreamWriteFunc function for output model is null"); } - model_compile_options->SetOutputModelOutStream(write_stream_func, state); + model_compile_options->SetOutputModelOutStream(write_stream_func, stream_state); return nullptr; #else ORT_UNUSED_PARAMETER(ort_model_compile_options); diff --git a/onnxruntime/core/session/compile_api.h b/onnxruntime/core/session/compile_api.h index 1be772e800ac2..231bd9a82eab5 100644 --- a/onnxruntime/core/session/compile_api.h +++ b/onnxruntime/core/session/compile_api.h @@ -31,6 +31,6 @@ ORT_API_STATUS_IMPL(CompileModel, _In_ const OrtEnv* env, _In_ const OrtModelCom ORT_API_STATUS_IMPL(ModelCompilationOptions_SetInputModel, _In_ OrtModelCompilationOptions* 
model_compile_options, _In_ const OrtModel* input_model); ORT_API_STATUS_IMPL(ModelCompilationOptions_SetOutputModelOutStream, _In_ OrtModelCompilationOptions* model_compile_options, - _In_ OrtOutStreamWriteFunc write_stream_func, _In_ void* state); + _In_ OrtOutStreamWriteFunc write_stream_func, _In_ void* stream_state); } // namespace OrtCompileAPI diff --git a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc index 27e635ef29f7a..1b3f300b56b5e 100644 --- a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc @@ -558,7 +558,7 @@ TEST_F(QnnHTPBackendTests, CompileApi_FromSessionOptions_OutputModelBuffer_Outpu allocator.Free(output_model_buffer); } -// Implementation of OrtOutStreamWriteFunc that writes out model to a file. +// Implementation of OrtOutStreamWriteFunc that writes the compiled model to a file. static OrtStatus* ORT_API_CALL TestWriteToStream(void* stream_state, const void* buffer, size_t buffer_num_bytes, size_t* num_bytes_written) { std::ofstream* outfile = reinterpret_cast(stream_state); @@ -594,7 +594,7 @@ static OrtStatus* ORT_API_CALL ReturnStatusFromStream(void* stream_state, const } // Test using the CompileModel() API with settings: -// - input OrtModel created via OrtModelEditor API +// - input OrtModel created via the model editor API // - write output model to custom stream TEST_F(QnnHTPBackendTests, CompileApi_InputOrtModel_OutputToStream) { std::vector>> weights; // Model weights must remain valid through inference From 1ee829fb34610f3852391db9b606ab4cd725c02e Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Wed, 14 May 2025 12:38:33 -0700 Subject: [PATCH 09/14] Handle case where user's write function never writes data (return error) --- .../core/session/onnxruntime_c_api.h | 18 +++++--- .../core/framework/ep_context_options.cc | 17 +++++++- .../test/providers/qnn/qnn_ep_context_test.cc | 43 +++++++++++++++++++ 3 
files changed, 72 insertions(+), 6 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index f57b535a4230e..8837b89c6583d 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -5847,8 +5847,9 @@ struct OrtCompileApi { /** \brief Sets the file path to the input ONNX model to compile. * - * The input model's location (e.g., file path or memory buffer) must be set with either - * ModelCompilationOptions_SetInputModelPath or ModelCompilationOptions_SetInputModelFromBuffer. + * The input model's location (e.g., file path, memory buffer, or OrtModel) must be set with one of the functions that + * begin with ModelCompilationOptions_SetInputModel____, otherwise CompileModel() will return an OrtStatus with error + * code ORT_INVALID_ARGUMENT. * * \param[in] model_compile_options The OrtModelCompilationOptions instance. * \param[in] input_model_path Null terminated string of the path (wchar on Windows, char otherwise). @@ -5862,8 +5863,9 @@ struct OrtCompileApi { /** \brief Sets the buffer that stores the bytes of the loaded ONNX model to compile. * - * The input model's location (e.g., file path or memory buffer) must be set with either - * ModelCompilationOptions_SetInputModelPath or ModelCompilationOptions_SetInputModelFromBuffer. + * The input model's location (e.g., file path, memory buffer, or OrtModel) must be set with one of the functions that + * begin with ModelCompilationOptions_SetInputModel____, otherwise CompileModel() will return an OrtStatus with error + * code ORT_INVALID_ARGUMENT. * * \param[in] model_compile_options The OrtModelCompilationOptions instance. * \param[in] input_model_data Buffer containing the loaded ONNX model bytes. 
@@ -5982,6 +5984,10 @@ struct OrtCompileApi { ORT_API2_STATUS(CompileModel, _In_ const OrtEnv* env, _In_ const OrtModelCompilationOptions* model_options); /** \brief Sets the input OrtModel instance to compile. + * + * The input model's location (e.g., file path, memory buffer, or OrtModel) must be set with one of the functions that + * begin with ModelCompilationOptions_SetInputModel____, otherwise CompileModel() will return an OrtStatus with error + * code ORT_INVALID_ARGUMENT. * * \param[in] model_compile_options The OrtModelCompilationOptions instance. * \param[in] input_model The OrtModel instance to compile. @@ -5995,7 +6001,9 @@ struct OrtCompileApi { /** \brief Sets an output stream used to write out the output model's serialized ONNX bytes. * - * The write function is called repeatedly until then entire output model has been written out. + * The provided write function is called repeatedly until then entire output model has been written out. Each call to + * the write function must write at least one byte, otherwise CompileModel() will return an OrtStatus with error code + * ORT_FAIL. * * The output model's destination (e.g., file path, memory buffer, or stream) can be set with any of the functions * that begin with ModelCompilationOptions_SetOutputModel____. diff --git a/onnxruntime/core/framework/ep_context_options.cc b/onnxruntime/core/framework/ep_context_options.cc index df2f3c6a42456..97a135c761d9d 100644 --- a/onnxruntime/core/framework/ep_context_options.cc +++ b/onnxruntime/core/framework/ep_context_options.cc @@ -2,6 +2,9 @@ // Licensed under the MIT License. #include +#include +#include +#include #include "core/common/common.h" #include "core/framework/ep_context_options.h" #include "core/framework/error_code_helper.h" @@ -53,6 +56,7 @@ OutStreamBuf::~OutStreamBuf() { sync(); } +// Called when the buffer_ is full. Flushes the buffer_ (via sync()) and then writes the overflow character to buffer_. 
std::streambuf::int_type OutStreamBuf::overflow(std::streambuf::int_type ch) { if (sync() == -1) { return traits_type::eof(); @@ -66,13 +70,18 @@ std::streambuf::int_type OutStreamBuf::overflow(std::streambuf::int_type ch) { return ch; } +// Flushes the entire buffer_ to the user's write function. int OutStreamBuf::sync() { + if (!last_status_.IsOK()) { + return -1; + } + std::ptrdiff_t num_bytes = pptr() - pbase(); if (num_bytes == 0) { return 0; } - // Can only call pbump() with an int, so can only write at most 2^31 - 1. + // Can only call pbump() with an int, so can only write at most (2^31 - 1) bytes. if (num_bytes > std::numeric_limits::max()) { num_bytes = std::numeric_limits::max(); } @@ -106,6 +115,12 @@ int OutStreamBuf::sync() { return -1; } + if (bytes_written == 0) { + last_status_ = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "OrtOutStreamWriteFunc failed to write any data. ", + "Stopping write attempts to avoid a potential infinite loop."); + return -1; + } + bytes_remaining -= static_cast(bytes_written); ptr += bytes_written; } diff --git a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc index 1b3f300b56b5e..a4d088cf8dc62 100644 --- a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc @@ -593,6 +593,18 @@ static OrtStatus* ORT_API_CALL ReturnStatusFromStream(void* stream_state, const return Ort::GetApi().CreateStatus(ORT_FAIL, "Error from OrtOutStreamWriteFunc callback"); } +// Implementation of OrtOutStreamWriteFunc that never writes any data. ORT should abort writing attempts to prevent +// an infinite loop. 
+static OrtStatus* ORT_API_CALL NoWriteStream(void* stream_state, const void* buffer, size_t buffer_num_bytes, + size_t* num_bytes_written) { + ORT_UNUSED_PARAMETER(stream_state); + ORT_UNUSED_PARAMETER(buffer); + ORT_UNUSED_PARAMETER(buffer_num_bytes); + + *num_bytes_written = 0; + return nullptr; +} + // Test using the CompileModel() API with settings: // - input OrtModel created via the model editor API // - write output model to custom stream @@ -728,6 +740,37 @@ TEST_F(QnnHTPBackendTests, CompileApi_OutputStream_ReturnStatus) { EXPECT_EQ(status.GetErrorMessage(), "Error from OrtOutStreamWriteFunc callback"); } +// Tests using an OrtOutStreamFunc function that never writes any data. ORT should abort write attempts +// with an error to prevent a potential infinite loop. +TEST_F(QnnHTPBackendTests, CompileApi_OutputStream_NoWrite_AbortInfiniteWriteLoop) { + const ORTCHAR_T* input_model_file = ORT_TSTR("./compileapi_outputstream_zerowrite.onnx"); + std::filesystem::remove(input_model_file); + + // Create a test model and save it to a file. + TestModel test_model; + CreateTestModel(BuildGraphWithQAndNonQ(false), 21, logging::Severity::kERROR, test_model); + ASSERT_STATUS_OK(test_model.Save(input_model_file)); + + // Initialize session options with QNN EP + Ort::SessionOptions so; + ProviderOptions provider_options; + provider_options["backend_type"] = "htp"; + provider_options["offload_graph_io_quantization"] = "0"; + so.AppendExecutionProvider("QNN", provider_options); + + // Create model compilation options from the session options. + Ort::ModelCompilationOptions compile_options(*ort_env, so); + compile_options.SetInputModelPath(input_model_file); + compile_options.SetOutputModelOutStream(NoWriteStream, nullptr); // Set output stream that doesn't write data. + compile_options.SetEpContextEmbedMode(true); + + // Compile the model. Expect an error status because our stream would be stuck in an infinite loop. 
+ Ort::Status status = Ort::CompileModel(*ort_env, compile_options); + ASSERT_FALSE(status.IsOK()); + EXPECT_EQ(status.GetErrorCode(), ORT_FAIL); + EXPECT_TRUE(status.GetErrorMessage().find("OrtOutStreamWriteFunc failed to write any data") != std::string::npos); +} + // Test that models with 1 non-quantized FusedMatMul node and 1 quantized Add node can still generate the context binary // The generated Onnx model has 1 FusedMatMul node and 1 EPContext node TEST_F(QnnHTPBackendTests, QnnContextBinaryMultiPartitionSupport1) { From a21be474939a76143275521a08a6a33b6134f7e9 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Wed, 14 May 2025 13:01:58 -0700 Subject: [PATCH 10/14] Reduce repetition in unit tests --- .../core/framework/ep_context_options.h | 2 +- .../test/providers/qnn/qnn_ep_context_test.cc | 55 +++++++------------ 2 files changed, 22 insertions(+), 35 deletions(-) diff --git a/onnxruntime/core/framework/ep_context_options.h b/onnxruntime/core/framework/ep_context_options.h index f85409dbb2fb1..94fc8c7b0a895 100644 --- a/onnxruntime/core/framework/ep_context_options.h +++ b/onnxruntime/core/framework/ep_context_options.h @@ -66,7 +66,7 @@ struct ModelGenOptions { // std::ostream out_stream(out_stream_buf.get()); class OutStreamBuf : public std::streambuf { public: - OutStreamBuf(OutStreamHolder out_stream_holder); + explicit OutStreamBuf(OutStreamHolder out_stream_holder); ~OutStreamBuf(); Status GetStatus() const { diff --git a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc index a4d088cf8dc62..d40005b053f06 100644 --- a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc @@ -26,6 +26,14 @@ namespace test { #if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +// Returns QNN provider options that use the HTP backend (npu) and do not offload graph I/O qdq. 
+static ProviderOptions QnnHTPOptionsWithoutQDQOffloading() { + ProviderOptions provider_options; + provider_options["backend_type"] = "htp"; + provider_options["offload_graph_io_quantization"] = "0"; + return provider_options; +} + static int64_t GetNodeAttr(const Node& node, const std::string& attr_name, int64_t default_val) { const auto& attributes = node.GetAttributes(); if (auto entry = attributes.find(attr_name); entry != attributes.end()) { @@ -649,10 +657,7 @@ TEST_F(QnnHTPBackendTests, CompileApi_InputOrtModel_OutputToStream) { // Initialize session options with QNN EP Ort::SessionOptions so; - ProviderOptions provider_options; - provider_options["backend_type"] = "htp"; - provider_options["offload_graph_io_quantization"] = "0"; - so.AppendExecutionProvider("QNN", provider_options); + so.AppendExecutionProvider("QNN", QnnHTPOptionsWithoutQDQOffloading()); const ORTCHAR_T* output_model_file = ORT_TSTR("compileapi_ortmodel_ctx.onnx"); std::filesystem::remove(output_model_file); @@ -682,24 +687,18 @@ TEST_F(QnnHTPBackendTests, CompileApi_InputOrtModel_OutputToStream) { // Tests using an OrtOutStreamFunc function that writes too much. TEST_F(QnnHTPBackendTests, CompileApi_OutputStream_WriteTooMuch) { - const ORTCHAR_T* input_model_file = ORT_TSTR("./compileapi_outputstream_writetoomuch.onnx"); - std::filesystem::remove(input_model_file); - - // Create a test model and save it to a file. + // Create a test model (in memory). 
TestModel test_model; CreateTestModel(BuildGraphWithQAndNonQ(false), 21, logging::Severity::kERROR, test_model); - ASSERT_STATUS_OK(test_model.Save(input_model_file)); + std::string model_data = test_model.Serialize(); // Initialize session options with QNN EP Ort::SessionOptions so; - ProviderOptions provider_options; - provider_options["backend_type"] = "htp"; - provider_options["offload_graph_io_quantization"] = "0"; - so.AppendExecutionProvider("QNN", provider_options); + so.AppendExecutionProvider("QNN", QnnHTPOptionsWithoutQDQOffloading()); // Create model compilation options from the session options. Ort::ModelCompilationOptions compile_options(*ort_env, so); - compile_options.SetInputModelPath(input_model_file); + compile_options.SetInputModelFromBuffer(reinterpret_cast(model_data.data()), model_data.size()); compile_options.SetOutputModelOutStream(WriteTooMuchToStream, nullptr); // Set output stream that writes too much compile_options.SetEpContextEmbedMode(true); @@ -712,24 +711,18 @@ TEST_F(QnnHTPBackendTests, CompileApi_OutputStream_WriteTooMuch) { // Tests using an OrtOutStreamFunc function that returns an error. TEST_F(QnnHTPBackendTests, CompileApi_OutputStream_ReturnStatus) { - const ORTCHAR_T* input_model_file = ORT_TSTR("./compileapi_outputstream_returnstatus.onnx"); - std::filesystem::remove(input_model_file); - - // Create a test model and save it to a file. + // Create a test model (in memory). 
TestModel test_model; CreateTestModel(BuildGraphWithQAndNonQ(false), 21, logging::Severity::kERROR, test_model); - ASSERT_STATUS_OK(test_model.Save(input_model_file)); + std::string model_data = test_model.Serialize(); // Initialize session options with QNN EP Ort::SessionOptions so; - ProviderOptions provider_options; - provider_options["backend_type"] = "htp"; - provider_options["offload_graph_io_quantization"] = "0"; - so.AppendExecutionProvider("QNN", provider_options); + so.AppendExecutionProvider("QNN", QnnHTPOptionsWithoutQDQOffloading()); // Create model compilation options from the session options. Ort::ModelCompilationOptions compile_options(*ort_env, so); - compile_options.SetInputModelPath(input_model_file); + compile_options.SetInputModelFromBuffer(reinterpret_cast(model_data.data()), model_data.size()); compile_options.SetOutputModelOutStream(ReturnStatusFromStream, nullptr); // Set output stream that returns error compile_options.SetEpContextEmbedMode(true); @@ -743,24 +736,18 @@ TEST_F(QnnHTPBackendTests, CompileApi_OutputStream_ReturnStatus) { // Tests using an OrtOutStreamFunc function that never writes any data. ORT should abort write attempts // with an error to prevent a potential infinite loop. TEST_F(QnnHTPBackendTests, CompileApi_OutputStream_NoWrite_AbortInfiniteWriteLoop) { - const ORTCHAR_T* input_model_file = ORT_TSTR("./compileapi_outputstream_zerowrite.onnx"); - std::filesystem::remove(input_model_file); - - // Create a test model and save it to a file. + // Create a test model (in memory). 
TestModel test_model; CreateTestModel(BuildGraphWithQAndNonQ(false), 21, logging::Severity::kERROR, test_model); - ASSERT_STATUS_OK(test_model.Save(input_model_file)); + std::string model_data = test_model.Serialize(); // Initialize session options with QNN EP Ort::SessionOptions so; - ProviderOptions provider_options; - provider_options["backend_type"] = "htp"; - provider_options["offload_graph_io_quantization"] = "0"; - so.AppendExecutionProvider("QNN", provider_options); + so.AppendExecutionProvider("QNN", QnnHTPOptionsWithoutQDQOffloading()); // Create model compilation options from the session options. Ort::ModelCompilationOptions compile_options(*ort_env, so); - compile_options.SetInputModelPath(input_model_file); + compile_options.SetInputModelFromBuffer(reinterpret_cast(model_data.data()), model_data.size()); compile_options.SetOutputModelOutStream(NoWriteStream, nullptr); // Set output stream that doesn't write data. compile_options.SetEpContextEmbedMode(true); From 78df685969d9d2fb02b68af80523cd384bfd2b37 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Wed, 14 May 2025 14:35:08 -0700 Subject: [PATCH 11/14] Add Python bindings for compiling to a stream --- .../onnxruntime_inference_collection.py | 12 ++- .../onnxruntime_pybind_model_compiler.cc | 42 +++++++++++ .../onnxruntime_pybind_model_compiler.h | 18 ++++- .../python/onnxruntime_pybind_state.cc | 14 +++- .../onnxruntime_test_python_compile_api.py | 73 ++++++++++++++++++- 5 files changed, 154 insertions(+), 5 deletions(-) diff --git a/onnxruntime/python/onnxruntime_inference_collection.py b/onnxruntime/python/onnxruntime_inference_collection.py index e7dc4294f3672..fcb985ed32898 100644 --- a/onnxruntime/python/onnxruntime_inference_collection.py +++ b/onnxruntime/python/onnxruntime_inference_collection.py @@ -9,7 +9,7 @@ import os import typing import warnings -from collections.abc import Sequence +from collections.abc import Callable, Sequence from typing import Any from onnxruntime.capi import 
_pybind_state as C @@ -726,6 +726,16 @@ def compile_to_bytes(self) -> bytes: """ return self._model_compiler.compile_to_bytes() + def compile_to_stream(self, write_function: Callable[[bytes], int]): + """ + Compiles the input model and writes the serialized ONNX bytes to a stream using the provided write function. + + Raises an 'InvalidArgument' exception if the compilation options are invalid. + + :param write_function: A callable that accepts a bytes buffer to write and returns the number of bytes written. + """ + self._model_compiler.compile_to_stream(write_function) + class IOBinding: """ diff --git a/onnxruntime/python/onnxruntime_pybind_model_compiler.cc b/onnxruntime/python/onnxruntime_pybind_model_compiler.cc index 8bb7ee2098caf..c04e520737865 100644 --- a/onnxruntime/python/onnxruntime_pybind_model_compiler.cc +++ b/onnxruntime/python/onnxruntime_pybind_model_compiler.cc @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. +#if !defined(ORT_MINIMAL_BUILD) #include "python/onnxruntime_pybind_model_compiler.h" #include @@ -72,9 +73,50 @@ onnxruntime::Status PyModelCompiler::CompileToBytes(std::string& output_buffer) return Status::OK(); } +/** + * Calls the user's Python PyOutStreamWriteFunc function and converts the results to a form that can be used + * by ORT to write out a compiled ONNX model. + * + * @param stream_state Opaque state that holds a pointer to the user's Python function. + * @param buffer The buffer to write out. Contains a portion of the compiled ONNX model's bytes. + * @param buffer_num_bytes The number of bytes to write out. + * @param num_bytes_written Output parameter set to the actual number of bytes written by the user's Python function. + * + * @return nullptr OrtStatus* to indicate success. 
+ */ +static OrtStatus* ORT_API_CALL PyOutStreamWriteFuncWrapper(void* stream_state, const void* buffer, + size_t buffer_num_bytes, size_t* num_bytes_written) { + PyOutStreamWriteFunc* py_write_func = reinterpret_cast(stream_state); + OrtStatus* status = nullptr; + + // Call the Python write function and convert any exceptions to a status. + *num_bytes_written = 0; + ORT_TRY { + pybind11::bytes py_bytes(reinterpret_cast(buffer), buffer_num_bytes); + *num_bytes_written = (*py_write_func)(py_bytes); + } + ORT_CATCH(const std::exception& e) { + ORT_HANDLE_EXCEPTION([&]() { + status = ToOrtStatus(ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, e.what())); + }); + } + + if (status != nullptr) { + return status; + } + return nullptr; +} + +onnxruntime::Status PyModelCompiler::CompileToOutStream(PyOutStreamWriteFunc& write_func) { + model_compile_options_.SetOutputModelOutStream(PyOutStreamWriteFuncWrapper, reinterpret_cast(&write_func)); + ORT_RETURN_IF_ERROR(onnxruntime::CompileModel(*env_, model_compile_options_)); + return Status::OK(); +} + PyModelCompiler::PyModelCompiler(std::shared_ptr env, const PySessionOptions& sess_options, PrivateConstructorTag) : env_(env), model_compile_options_(*env, sess_options) { } } // namespace python } // namespace onnxruntime +#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/python/onnxruntime_pybind_model_compiler.h b/onnxruntime/python/onnxruntime_pybind_model_compiler.h index 6c9f48fa00ba6..d5020e6544c18 100644 --- a/onnxruntime/python/onnxruntime_pybind_model_compiler.h +++ b/onnxruntime/python/onnxruntime_pybind_model_compiler.h @@ -3,7 +3,6 @@ // Licensed under the MIT License. #pragma once -#if !defined(ORT_MINIMAL_BUILD) #include #include #include "core/common/status.h" @@ -14,11 +13,19 @@ namespace onnxruntime { class Environment; namespace python { +// Type of the function provided by Python code that is called by ORT to write out the compiled model. +// Returns the number of bytes written to the underlying stream. 
+using PyOutStreamWriteFunc = std::function; + /// /// Class exposed to Python that enables compiling ONNX models. /// Internally wraps a onnxruntime::ModelCompilationOptions that stores and validates settings. /// class PyModelCompiler { +#if defined(ORT_MINIMAL_BUILD) + public: + bool not_defined_in_this_build{}; // Prevent empty class warning. +#else private: // private tag to pass to constructor to ensure that constructor cannot be directly called externally struct PrivateConstructorTag {}; @@ -68,11 +75,18 @@ class PyModelCompiler { /// A Status indicating error or success. onnxruntime::Status CompileToBytes(std::string& output_buffer); + /// + /// Compiles the input model and writes the result into the provided output stream (write functor). + /// + /// Write functor that encapsulates the stream's state. + /// A Status indicating error or success. + onnxruntime::Status CompileToOutStream(PyOutStreamWriteFunc& write_func); + private: std::shared_ptr env_; onnxruntime::ModelCompilationOptions model_compile_options_; std::string input_model_bytes_; +#endif // !defined(ORT_MINIMAL_BUILD) }; } // namespace python } // namespace onnxruntime -#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index aa2c0cc6a0f86..9e19561749ca1 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -2832,7 +2832,19 @@ including arg name, arg type (contains both type and shape).)pbdoc") ORT_THROW("Compile API is not supported in this build."); #endif }, - R"pbdoc(Compile an ONNX model into a buffer.)pbdoc"); + R"pbdoc(Compile an ONNX model into a buffer.)pbdoc") + .def( + "compile_to_stream", + [](PyModelCompiler* model_compiler, PyOutStreamWriteFunc& py_stream_write_func) { +#if !defined(ORT_MINIMAL_BUILD) + OrtPybindThrowIfError(model_compiler->CompileToOutStream(py_stream_write_func)); +#else + 
ORT_UNUSED_PARAMETER(model_compiler); + ORT_UNUSED_PARAMETER(py_stream_write_func); + ORT_THROW("Compile API is not supported in this build."); +#endif + }, + R"pbdoc(Compile an ONNX model into an output stream using the provided write functor.)pbdoc"); } bool CreateInferencePybindStateModule(py::module& m) { diff --git a/onnxruntime/test/python/onnxruntime_test_python_compile_api.py b/onnxruntime/test/python/onnxruntime_test_python_compile_api.py index 7a410d4bbeb6a..7970917886c48 100644 --- a/onnxruntime/test/python/onnxruntime_test_python_compile_api.py +++ b/onnxruntime/test/python/onnxruntime_test_python_compile_api.py @@ -13,7 +13,7 @@ from helper import get_name import onnxruntime as onnxrt -from onnxruntime.capi.onnxruntime_pybind11_state import ModelRequiresCompilation +from onnxruntime.capi.onnxruntime_pybind11_state import Fail, ModelRequiresCompilation # handle change from python 3.8 and on where loading a dll from the current directory needs to be explicitly allowed. if platform.system() == "Windows" and sys.version_info.major >= 3 and sys.version_info.minor >= 8: # noqa: YTT204 @@ -176,6 +176,77 @@ def test_compile_from_buffer_to_buffer(self): self.assertTrue(isinstance(output_model_bytes, bytes)) self.assertGreater(len(output_model_bytes), 0) + def test_compile_from_file_to_stream(self): + """ + Tests compiling a model (from files) to an output stream using a custom write functor. + """ + provider = None + provider_options = dict() + if "QNNExecutionProvider" in available_providers: + provider = "QNNExecutionProvider" + provider_options["backend_type"] = "htp" + # TODO(adrianlizarraga): Allow test to run for other compiling EPs (e.g., OpenVINO) + + input_model_path = get_name("nhwc_resize_scales_opset18.onnx") + output_model_path = os.path.join(self._tmp_dir_path, "model.compiled.stream.onnx") + + with open(output_model_path, "wb") as output_fd: + # User's custom write functor. Writes the model to a file. 
+ def my_write_func(buffer: bytes) -> int: + self.assertGreater(len(buffer), 0) + return output_fd.write(buffer) + + session_options = onnxrt.SessionOptions() + if provider: + session_options.add_provider(provider, provider_options) + + model_compiler = onnxrt.ModelCompiler( + session_options, + input_model_path, + embed_compiled_data_into_model=True, + external_initializers_file_path=None, + ) + model_compiler.compile_to_stream(my_write_func) + + self.assertTrue(os.path.exists(output_model_path)) + onnx.checker.check_model(output_model_path) # Check that compiled model can be loaded. + + def test_compile_to_stream_that_raises_exception(self): + """ + Tests compiling a model to an output stream that always raises an exception. + """ + provider = None + provider_options = dict() + if "QNNExecutionProvider" in available_providers: + provider = "QNNExecutionProvider" + provider_options["backend_type"] = "htp" + # TODO(adrianlizarraga): Allow test to run for other compiling EPs (e.g., OpenVINO) + + input_model_path = get_name("nhwc_resize_scales_opset18.onnx") + + # User's custom write functor that raises an exception. + test_py_error_message = "My Python Error" + + def my_write_func(buffer: bytes) -> int: + self.assertGreater(len(buffer), 0) + raise ValueError(test_py_error_message) + + session_options = onnxrt.SessionOptions() + if provider: + session_options.add_provider(provider, provider_options) + + model_compiler = onnxrt.ModelCompiler( + session_options, + input_model_path, + embed_compiled_data_into_model=True, + external_initializers_file_path=None, + ) + + # Try to compile and expect ORT to raise a Fail exception that contains our message. 
+ with self.assertRaises(Fail) as context: + model_compiler.compile_to_stream(my_write_func) + self.assertIn(test_py_error_message, str(context.exception)) + def test_fail_load_uncompiled_model_and_then_compile(self): """ Tests compiling scenario: From 09a90a17dfc0242cbdd913f7ce50664b50b657a4 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Wed, 14 May 2025 14:49:43 -0700 Subject: [PATCH 12/14] fix min build unused var --- onnxruntime/core/session/compile_api.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/session/compile_api.cc b/onnxruntime/core/session/compile_api.cc index 9166ce4b11f41..12f1745a99678 100644 --- a/onnxruntime/core/session/compile_api.cc +++ b/onnxruntime/core/session/compile_api.cc @@ -223,7 +223,7 @@ ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetOutputModelOutStre #else ORT_UNUSED_PARAMETER(ort_model_compile_options); ORT_UNUSED_PARAMETER(write_stream_func); - ORT_UNUSED_PARAMETER(state); + ORT_UNUSED_PARAMETER(stream_state); return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Compile API is not supported in this build"); #endif // !defined(ORT_MINIMAL_BUILD) API_IMPL_END From d12f43524ae99074409cd75c12f33dc735f07096 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Wed, 14 May 2025 17:32:30 -0700 Subject: [PATCH 13/14] Flush before checking status of write --- onnxruntime/core/framework/ep_context_options.h | 2 +- onnxruntime/core/framework/graph_partitioner.cc | 1 + .../python/onnxruntime_pybind_model_compiler.cc | 5 +---- .../python/onnxruntime_test_python_compile_api.py | 12 ------------ 4 files changed, 3 insertions(+), 17 deletions(-) diff --git a/onnxruntime/core/framework/ep_context_options.h b/onnxruntime/core/framework/ep_context_options.h index 94fc8c7b0a895..7e46e664f90ce 100644 --- a/onnxruntime/core/framework/ep_context_options.h +++ b/onnxruntime/core/framework/ep_context_options.h @@ -69,7 +69,7 @@ class OutStreamBuf : public std::streambuf { explicit 
OutStreamBuf(OutStreamHolder out_stream_holder); ~OutStreamBuf(); - Status GetStatus() const { + const Status& GetStatus() const { return last_status_; } diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index 7a53563babddf..9465b8a7d4c87 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ -926,6 +926,7 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers std::ostream out_stream(out_stream_buf.get()); model_proto.SerializeToOstream(&out_stream); + out_stream.flush(); ORT_RETURN_IF_ERROR(out_stream_buf->GetStatus()); } else { // Write output model to file. diff --git a/onnxruntime/python/onnxruntime_pybind_model_compiler.cc b/onnxruntime/python/onnxruntime_pybind_model_compiler.cc index c04e520737865..bc98fa642e0ba 100644 --- a/onnxruntime/python/onnxruntime_pybind_model_compiler.cc +++ b/onnxruntime/python/onnxruntime_pybind_model_compiler.cc @@ -101,10 +101,7 @@ static OrtStatus* ORT_API_CALL PyOutStreamWriteFuncWrapper(void* stream_state, c }); } - if (status != nullptr) { - return status; - } - return nullptr; + return status; } onnxruntime::Status PyModelCompiler::CompileToOutStream(PyOutStreamWriteFunc& write_func) { diff --git a/onnxruntime/test/python/onnxruntime_test_python_compile_api.py b/onnxruntime/test/python/onnxruntime_test_python_compile_api.py index 7970917886c48..cdca3b0ac5521 100644 --- a/onnxruntime/test/python/onnxruntime_test_python_compile_api.py +++ b/onnxruntime/test/python/onnxruntime_test_python_compile_api.py @@ -185,7 +185,6 @@ def test_compile_from_file_to_stream(self): if "QNNExecutionProvider" in available_providers: provider = "QNNExecutionProvider" provider_options["backend_type"] = "htp" - # TODO(adrianlizarraga): Allow test to run for other compiling EPs (e.g., OpenVINO) input_model_path = get_name("nhwc_resize_scales_opset18.onnx") output_model_path = 
os.path.join(self._tmp_dir_path, "model.compiled.stream.onnx") @@ -209,19 +208,11 @@ def my_write_func(buffer: bytes) -> int: model_compiler.compile_to_stream(my_write_func) self.assertTrue(os.path.exists(output_model_path)) - onnx.checker.check_model(output_model_path) # Check that compiled model can be loaded. def test_compile_to_stream_that_raises_exception(self): """ Tests compiling a model to an output stream that always raises an exception. """ - provider = None - provider_options = dict() - if "QNNExecutionProvider" in available_providers: - provider = "QNNExecutionProvider" - provider_options["backend_type"] = "htp" - # TODO(adrianlizarraga): Allow test to run for other compiling EPs (e.g., OpenVINO) - input_model_path = get_name("nhwc_resize_scales_opset18.onnx") # User's custom write functor that raises an exception. @@ -232,9 +223,6 @@ def my_write_func(buffer: bytes) -> int: raise ValueError(test_py_error_message) session_options = onnxrt.SessionOptions() - if provider: - session_options.add_provider(provider, provider_options) - model_compiler = onnxrt.ModelCompiler( session_options, input_model_path, From e5c27c30ce9b784867198a87295fcdcd0075f9a0 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Wed, 14 May 2025 18:11:40 -0700 Subject: [PATCH 14/14] Simplify: expect write function to write entire provided buffer --- .../core/session/onnxruntime_c_api.h | 9 +- .../core/framework/ep_context_options.cc | 52 ++++------- .../onnxruntime_inference_collection.py | 4 +- .../onnxruntime_pybind_model_compiler.cc | 6 +- .../onnxruntime_pybind_model_compiler.h | 2 +- .../test/providers/qnn/qnn_ep_context_test.cc | 89 +------------------ .../onnxruntime_test_python_compile_api.py | 6 +- 7 files changed, 31 insertions(+), 137 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 8837b89c6583d..e8420a12fe01e 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ 
b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -458,7 +458,6 @@ typedef OrtStatus*(ORT_API_CALL* EpSelectionDelegate)(_In_ const OrtEpDevice** e * \param steam_state Opaque pointer holding the state for the user's stream. * \param buffer The buffer to write to the stream. * \param buffer_num_bytes The size of the buffer in bytes. - * \param num_bytes_written Output parameter that should be set to the number of bytes written to the stream. * * \return OrtStatus* Write status. Return nullptr on success. * Use CreateStatus to provide error info. Use ORT_FAIL as the error code. @@ -466,8 +465,7 @@ typedef OrtStatus*(ORT_API_CALL* EpSelectionDelegate)(_In_ const OrtEpDevice** e */ typedef OrtStatus*(ORT_API_CALL* OrtOutStreamWriteFunc)(_In_ void* stream_state, _In_ const void* buffer, - _In_ size_t buffer_num_bytes, - _Out_ size_t* num_bytes_written); + _In_ size_t buffer_num_bytes); /** \brief Algorithm to use for cuDNN Convolution Op */ @@ -6001,9 +5999,8 @@ struct OrtCompileApi { /** \brief Sets an output stream used to write out the output model's serialized ONNX bytes. * - * The provided write function is called repeatedly until then entire output model has been written out. Each call to - * the write function must write at least one byte, otherwise CompileModel() will return an OrtStatus with error code - * ORT_FAIL. + * The provided write function may be called repeatedly until the entire output model has been written out. Each call + * to the write function is expected to write the entire buffer to the underlying stream. * * The output model's destination (e.g., file path, memory buffer, or stream) can be set with any of the functions * that begin with ModelCompilationOptions_SetOutputModel____.
diff --git a/onnxruntime/core/framework/ep_context_options.cc b/onnxruntime/core/framework/ep_context_options.cc index 97a135c761d9d..60d2e25cc619f 100644 --- a/onnxruntime/core/framework/ep_context_options.cc +++ b/onnxruntime/core/framework/ep_context_options.cc @@ -86,46 +86,26 @@ int OutStreamBuf::sync() { num_bytes = std::numeric_limits<int>::max(); } - std::ptrdiff_t bytes_remaining = num_bytes; char* ptr = pbase(); - while (bytes_remaining > 0) { - size_t bytes_written = 0; - Status status = Status::OK(); - - ORT_TRY { - status = ToStatus(out_stream_holder_.write_func(out_stream_holder_.stream_state, - ptr, bytes_remaining, &bytes_written)); - } - ORT_CATCH(const std::exception& e) { - ORT_HANDLE_EXCEPTION([&]() { - status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, - "Caught exception while calling user's OrtOutStreamWriteFunc callback: ", e.what()); - }); - } - - if (!status.IsOK()) { - last_status_ = std::move(status); - return -1; - } - - if (bytes_written > static_cast<size_t>(bytes_remaining)) { - last_status_ = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "OrtOutStreamWriteFunc wrote more bytes (", bytes_written, - ") than requested (", bytes_remaining, ")."); - return -1; - } - - if (bytes_written == 0) { - last_status_ = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "OrtOutStreamWriteFunc failed to write any data. 
", - "Stopping write attempts to avoid a potential infinite loop."); - return -1; - } - - bytes_remaining -= static_cast<std::ptrdiff_t>(bytes_written); - ptr += bytes_written; + Status status = Status::OK(); + + ORT_TRY { + status = ToStatus(out_stream_holder_.write_func(out_stream_holder_.stream_state, + ptr, num_bytes)); + } + ORT_CATCH(const std::exception& e) { + ORT_HANDLE_EXCEPTION([&]() { + status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "Caught exception while calling user's OrtOutStreamWriteFunc callback: ", e.what()); + }); + } + + if (!status.IsOK()) { + last_status_ = std::move(status); + return -1; } - assert(ptr == pptr()); pbump(-static_cast<int>(num_bytes)); // Reset internal pointer to point to the beginning of the buffer_ return 0; } diff --git a/onnxruntime/python/onnxruntime_inference_collection.py b/onnxruntime/python/onnxruntime_inference_collection.py index fcb985ed32898..b2ed617a74e8b 100644 --- a/onnxruntime/python/onnxruntime_inference_collection.py +++ b/onnxruntime/python/onnxruntime_inference_collection.py @@ -726,13 +726,13 @@ def compile_to_bytes(self) -> bytes: """ return self._model_compiler.compile_to_bytes() - def compile_to_stream(self, write_function: Callable[[bytes], int]): + def compile_to_stream(self, write_function: Callable[[bytes], None]): """ Compiles the input model and writes the serialized ONNX bytes to a stream using the provided write function. Raises an 'InvalidArgument' exception if the compilation options are invalid. - :param write_function: A callable that accepts a bytes buffer to write and returns the number of bytes written. + :param write_function: A callable that accepts a bytes buffer to write.
""" self._model_compiler.compile_to_stream(write_function) diff --git a/onnxruntime/python/onnxruntime_pybind_model_compiler.cc b/onnxruntime/python/onnxruntime_pybind_model_compiler.cc index bc98fa642e0ba..f75931bd6b170 100644 --- a/onnxruntime/python/onnxruntime_pybind_model_compiler.cc +++ b/onnxruntime/python/onnxruntime_pybind_model_compiler.cc @@ -80,20 +80,18 @@ onnxruntime::Status PyModelCompiler::CompileToBytes(std::string& output_buffer) * @param stream_state Opaque state that holds a pointer to the user's Python function. * @param buffer The buffer to write out. Contains a portion of the compiled ONNX model's bytes. * @param buffer_num_bytes The number of bytes to write out. - * @param num_bytes_written Output parameter set to the actual number of bytes written by the user's Python function. * * @return nullptr OrtStatus* to indicate success. */ static OrtStatus* ORT_API_CALL PyOutStreamWriteFuncWrapper(void* stream_state, const void* buffer, - size_t buffer_num_bytes, size_t* num_bytes_written) { + size_t buffer_num_bytes) { PyOutStreamWriteFunc* py_write_func = reinterpret_cast<PyOutStreamWriteFunc*>(stream_state); OrtStatus* status = nullptr; // Call the Python write function and convert any exceptions to a status. - *num_bytes_written = 0; ORT_TRY { pybind11::bytes py_bytes(reinterpret_cast<const char*>(buffer), buffer_num_bytes); - *num_bytes_written = (*py_write_func)(py_bytes); + (*py_write_func)(py_bytes); } ORT_CATCH(const std::exception& e) { ORT_HANDLE_EXCEPTION([&]() { diff --git a/onnxruntime/python/onnxruntime_pybind_model_compiler.h b/onnxruntime/python/onnxruntime_pybind_model_compiler.h index d5020e6544c18..9fb59a33164b1 100644 --- a/onnxruntime/python/onnxruntime_pybind_model_compiler.h +++ b/onnxruntime/python/onnxruntime_pybind_model_compiler.h @@ -15,7 +15,7 @@ class Environment; namespace python { // Type of the function provided by Python code that is called by ORT to write out the compiled model. // Returns the number of bytes written to the underlying stream.
-using PyOutStreamWriteFunc = std::function<size_t(pybind11::bytes)>; +using PyOutStreamWriteFunc = std::function<void(pybind11::bytes)>; /// /// Class exposed to Python that enables compiling ONNX models. diff --git a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc index d40005b053f06..180c1bae7d077 100644 --- a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc @@ -567,52 +567,20 @@ TEST_F(QnnHTPBackendTests, CompileApi_FromSessionOptions_OutputModelBuffer_Outpu } // Implementation of OrtOutStreamWriteFunc that writes the compiled model to a file. -static OrtStatus* ORT_API_CALL TestWriteToStream(void* stream_state, const void* buffer, size_t buffer_num_bytes, - size_t* num_bytes_written) { +static OrtStatus* ORT_API_CALL TestWriteToStream(void* stream_state, const void* buffer, size_t buffer_num_bytes) { std::ofstream* outfile = reinterpret_cast<std::ofstream*>(stream_state); - - // Write out file in chunks of at most 2048 bytes to test multiple calls to this function. - size_t write_amount = std::min(static_cast<size_t>(2048), buffer_num_bytes); - outfile->write(reinterpret_cast<const char*>(buffer), write_amount); - *num_bytes_written = write_amount; - - return nullptr; -} - -// Implementation of OrtOutStreamWriteFunc that writes too much, resulting in an error. -static OrtStatus* ORT_API_CALL WriteTooMuchToStream(void* stream_state, const void* buffer, size_t buffer_num_bytes, - size_t* num_bytes_written) { - ORT_UNUSED_PARAMETER(stream_state); - ORT_UNUSED_PARAMETER(buffer); - - // Incorrectly say we wrote more than requested. ORT should return an error on call to CompileModel(). - *num_bytes_written = buffer_num_bytes + 1; - return nullptr; + outfile->write(reinterpret_cast<const char*>(buffer), buffer_num_bytes); + return nullptr; // No error } // Implementation of OrtOutStreamWriteFunc that directly returns an OrtStatus indicating an error. 
-static OrtStatus* ORT_API_CALL ReturnStatusFromStream(void* stream_state, const void* buffer, size_t buffer_num_bytes, - size_t* num_bytes_written) { +static OrtStatus* ORT_API_CALL ReturnStatusFromStream(void* stream_state, const void* buffer, size_t buffer_num_bytes) { ORT_UNUSED_PARAMETER(stream_state); ORT_UNUSED_PARAMETER(buffer); ORT_UNUSED_PARAMETER(buffer_num_bytes); - - *num_bytes_written = 0; return Ort::GetApi().CreateStatus(ORT_FAIL, "Error from OrtOutStreamWriteFunc callback"); } -// Implementation of OrtOutStreamWriteFunc that never writes any data. ORT should abort writing attempts to prevent -// an infinite loop. -static OrtStatus* ORT_API_CALL NoWriteStream(void* stream_state, const void* buffer, size_t buffer_num_bytes, - size_t* num_bytes_written) { - ORT_UNUSED_PARAMETER(stream_state); - ORT_UNUSED_PARAMETER(buffer); - ORT_UNUSED_PARAMETER(buffer_num_bytes); - - *num_bytes_written = 0; - return nullptr; -} - // Test using the CompileModel() API with settings: // - input OrtModel created via the model editor API // - write output model to custom stream @@ -685,30 +653,6 @@ TEST_F(QnnHTPBackendTests, CompileApi_InputOrtModel_OutputToStream) { CheckEpContextNodeCounts(output_model_file, 1, 0); } -// Tests using an OrtOutStreamFunc function that writes too much. -TEST_F(QnnHTPBackendTests, CompileApi_OutputStream_WriteTooMuch) { - // Create a test model (in memory). - TestModel test_model; - CreateTestModel(BuildGraphWithQAndNonQ(false), 21, logging::Severity::kERROR, test_model); - std::string model_data = test_model.Serialize(); - - // Initialize session options with QNN EP - Ort::SessionOptions so; - so.AppendExecutionProvider("QNN", QnnHTPOptionsWithoutQDQOffloading()); - - // Create model compilation options from the session options. 
- Ort::ModelCompilationOptions compile_options(*ort_env, so); - compile_options.SetInputModelFromBuffer(reinterpret_cast(model_data.data()), model_data.size()); - compile_options.SetOutputModelOutStream(WriteTooMuchToStream, nullptr); // Set output stream that writes too much - compile_options.SetEpContextEmbedMode(true); - - // Compile the model. Expect an error status because our stream wrote too much. - Ort::Status status = Ort::CompileModel(*ort_env, compile_options); - ASSERT_FALSE(status.IsOK()); - EXPECT_EQ(status.GetErrorCode(), ORT_FAIL); - EXPECT_TRUE(status.GetErrorMessage().find("OrtOutStreamWriteFunc wrote more bytes") != std::string::npos); -} - // Tests using an OrtOutStreamFunc function that returns an error. TEST_F(QnnHTPBackendTests, CompileApi_OutputStream_ReturnStatus) { // Create a test model (in memory). @@ -733,31 +677,6 @@ TEST_F(QnnHTPBackendTests, CompileApi_OutputStream_ReturnStatus) { EXPECT_EQ(status.GetErrorMessage(), "Error from OrtOutStreamWriteFunc callback"); } -// Tests using an OrtOutStreamFunc function that never writes any data. ORT should abort write attempts -// with an error to prevent a potential infinite loop. -TEST_F(QnnHTPBackendTests, CompileApi_OutputStream_NoWrite_AbortInfiniteWriteLoop) { - // Create a test model (in memory). - TestModel test_model; - CreateTestModel(BuildGraphWithQAndNonQ(false), 21, logging::Severity::kERROR, test_model); - std::string model_data = test_model.Serialize(); - - // Initialize session options with QNN EP - Ort::SessionOptions so; - so.AppendExecutionProvider("QNN", QnnHTPOptionsWithoutQDQOffloading()); - - // Create model compilation options from the session options. - Ort::ModelCompilationOptions compile_options(*ort_env, so); - compile_options.SetInputModelFromBuffer(reinterpret_cast(model_data.data()), model_data.size()); - compile_options.SetOutputModelOutStream(NoWriteStream, nullptr); // Set output stream that doesn't write data. 
- compile_options.SetEpContextEmbedMode(true); - - // Compile the model. Expect an error status because our stream would be stuck in an infinite loop. - Ort::Status status = Ort::CompileModel(*ort_env, compile_options); - ASSERT_FALSE(status.IsOK()); - EXPECT_EQ(status.GetErrorCode(), ORT_FAIL); - EXPECT_TRUE(status.GetErrorMessage().find("OrtOutStreamWriteFunc failed to write any data") != std::string::npos); -} - // Test that models with 1 non-quantized FusedMatMul node and 1 quantized Add node can still generate the context binary // The generated Onnx model has 1 FusedMatMul node and 1 EPContext node TEST_F(QnnHTPBackendTests, QnnContextBinaryMultiPartitionSupport1) { diff --git a/onnxruntime/test/python/onnxruntime_test_python_compile_api.py b/onnxruntime/test/python/onnxruntime_test_python_compile_api.py index cdca3b0ac5521..d7012f1c9d5e8 100644 --- a/onnxruntime/test/python/onnxruntime_test_python_compile_api.py +++ b/onnxruntime/test/python/onnxruntime_test_python_compile_api.py @@ -191,9 +191,9 @@ def test_compile_from_file_to_stream(self): with open(output_model_path, "wb") as output_fd: # User's custom write functor. Writes the model to a file. - def my_write_func(buffer: bytes) -> int: + def my_write_func(buffer: bytes): self.assertGreater(len(buffer), 0) - return output_fd.write(buffer) + output_fd.write(buffer) session_options = onnxrt.SessionOptions() if provider: @@ -218,7 +218,7 @@ def test_compile_to_stream_that_raises_exception(self): # User's custom write functor that raises an exception. test_py_error_message = "My Python Error" - def my_write_func(buffer: bytes) -> int: + def my_write_func(buffer: bytes): self.assertGreater(len(buffer), 0) raise ValueError(test_py_error_message)