diff --git a/src/grpc/CMakeLists.txt b/src/grpc/CMakeLists.txt index 0cd027a30a..1b0544c37c 100644 --- a/src/grpc/CMakeLists.txt +++ b/src/grpc/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright 2023-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions @@ -67,6 +67,7 @@ target_link_libraries( triton-common-json # from repo-common grpc-health-library # from repo-common grpc-service-library # from repo-common + grpccallback-service-library # from repo-common triton-core-serverapi # from repo-core triton-core-serverstub # from repo-core gRPC::grpc++ diff --git a/src/grpc/grpc_handler.h b/src/grpc/grpc_handler.h index 4f1bcdfac0..405a78d737 100644 --- a/src/grpc/grpc_handler.h +++ b/src/grpc/grpc_handler.h @@ -1,4 +1,4 @@ -// Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2023-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -25,14 +25,31 @@ // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #pragma once +#include <grpc++/grpc++.h> + #include <string> +#include "grpc_service.grpc.pb.h" +#include "grpccallback_service.grpc.pb.h" +#include "health.grpc.pb.h" + namespace triton { namespace server { namespace grpc { class HandlerBase { public: virtual ~HandlerBase() = default; virtual void Start() = 0; virtual void Stop() = 0; + virtual inference::GRPCInferenceServiceCallback::CallbackService* + GetUnifiedCallbackService() + { + return nullptr; + } + + virtual ::grpc::health::v1::Health::CallbackService* + GetHealthCallbackService() + { + return nullptr; + } }; class ICallData { diff --git a/src/grpc/grpc_server.cc b/src/grpc/grpc_server.cc index 362c60e81e..98030b24b9 100644 --- a/src/grpc/grpc_server.cc +++ b/src/grpc/grpc_server.cc @@ -77,622 +77,335 @@ namespace { // are deemed to be not performance critical. 
//========================================================================= -template <typename ResponderType, typename RequestType, typename ResponseType> -class CommonCallData : public ICallData { +// Define a dedicated health service that implements the health check RPC +class HealthCallbackService + : public ::grpc::health::v1::Health::CallbackService { public: - using StandardRegisterFunc = std::function<void( - ::grpc::ServerContext*, RequestType*, ResponderType*, void*)>; - using StandardCallbackFunc = - std::function<void(RequestType&, ResponseType*, ::grpc::Status*)>; - - CommonCallData( - const std::string& name, const uint64_t id, - const StandardRegisterFunc OnRegister, - const StandardCallbackFunc OnExecute, const bool async, - ::grpc::ServerCompletionQueue* cq, - const std::pair<std::string, std::string>& restricted_kv, - const uint64_t& response_delay = 0) - : name_(name), id_(id), OnRegister_(OnRegister), OnExecute_(OnExecute), - async_(async), cq_(cq), responder_(&ctx_), step_(Steps::START), - restricted_kv_(restricted_kv), response_delay_(response_delay) + HealthCallbackService( + const std::shared_ptr<TRITONSERVER_Server>& server, + RestrictedFeatures& restricted_keys_) + : tritonserver_(server), restricted_keys_(restricted_keys_) { - OnRegister_(&ctx_, &request_, &responder_, this); - LOG_VERBOSE(1) << "Ready for RPC '" << name_ << "', " << id_; } - ~CommonCallData() + ::grpc::ServerUnaryReactor* Check( + ::grpc::CallbackServerContext* context, + const ::grpc::health::v1::HealthCheckRequest* request, + ::grpc::health::v1::HealthCheckResponse* response) override { - if (async_thread_.joinable()) { - async_thread_.join(); - } - } - - bool Process(bool ok) override; - - std::string Name() override { return name_; } - - uint64_t Id() override { return id_; } - - private: - void Execute(); - void AddToCompletionQueue(); - void WriteResponse(); - bool ExecutePrecondition(); - - const std::string name_; - const uint64_t id_; - const StandardRegisterFunc OnRegister_; - const StandardCallbackFunc OnExecute_; - const bool async_; - ::grpc::ServerCompletionQueue* cq_; - - ::grpc::ServerContext ctx_; - ::grpc::Alarm alarm_; - - ResponderType responder_; - RequestType request_; - ResponseType response_; - ::grpc::Status status_; - - std::thread async_thread_; - - Steps step_; - - std::pair<std::string, std::string> restricted_kv_{"", ""}; - - const uint64_t response_delay_; -}; - -template <typename ResponderType, typename RequestType, typename ResponseType> -bool -CommonCallData<ResponderType, RequestType, ResponseType>::Process(bool rpc_ok) -{ - LOG_VERBOSE(1) << "Process for " << name_ << ", rpc_ok=" << rpc_ok << ", " - << id_ << " step " << step_; - - // If RPC failed on a new request then the server is shutting down - // and so we should do nothing (including not registering for a new - // request). If RPC failed on a non-START step then there is nothing - // we can do since we one execute one step. 
- const bool shutdown = (!rpc_ok && (step_ == Steps::START)); - if (shutdown) { - if (async_thread_.joinable()) { - async_thread_.join(); + auto* reactor = context->DefaultReactor(); + + // Check restricted access if configured + const std::pair<std::string, std::string>& restricted_kv = + restricted_keys_.Get(RestrictedCategory::HEALTH); + if (!restricted_kv.first.empty()) { + const auto& metadata = context->client_metadata(); + auto it = metadata.find(restricted_kv.first); + if (it == metadata.end() || it->second != restricted_kv.second) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + "Missing or mismatched restricted header")); + return reactor; + } } - step_ = Steps::FINISH; - } - if (step_ == Steps::START) { - // Start a new request to replace this one... - if (!shutdown) { - new CommonCallData<ResponderType, RequestType, ResponseType>( - name_, id_ + 1, OnRegister_, OnExecute_, async_, cq_, restricted_kv_, - response_delay_); - } + // Check if server is ready + bool ready = false; + TRITONSERVER_Error* err = + TRITONSERVER_ServerIsReady(tritonserver_.get(), &ready); - if (!async_) { - // For synchronous calls, execute and write response - // here. - Execute(); - WriteResponse(); + // Set health status based on server readiness + if (err == nullptr && ready) { + response->set_status(::grpc::health::v1::HealthCheckResponse::SERVING); } else { - // For asynchronous calls, delegate the execution to another - // thread. - step_ = Steps::ISSUED; - async_thread_ = std::thread(&CommonCallData::Execute, this); + response->set_status( + ::grpc::health::v1::HealthCheckResponse::NOT_SERVING); } - } else if (step_ == Steps::WRITEREADY) { - // Will only come here for asynchronous mode. - WriteResponse(); - } else if (step_ == Steps::COMPLETE) { - step_ = Steps::FINISH; - } - - return step_ != Steps::FINISH; -} - -template <typename ResponderType, typename RequestType, typename ResponseType> -void -CommonCallData<ResponderType, RequestType, ResponseType>::Execute() -{ - if (ExecutePrecondition()) { - OnExecute_(request_, &response_, &status_); - } else { - status_ = ::grpc::Status( - ::grpc::StatusCode::UNAVAILABLE, - std::string("This protocol is restricted, expecting header '") + - restricted_kv_.first + "'"); - } - step_ = Steps::WRITEREADY; - - if (async_) { - // For asynchronous operation, need to add itself onto the completion - // queue so that the response can be written once the object is - // taken up next for execution. 
- AddToCompletionQueue(); - } -} -template <typename ResponderType, typename RequestType, typename ResponseType> -bool -CommonCallData<ResponderType, RequestType, ResponseType>::ExecutePrecondition() -{ - if (!restricted_kv_.first.empty()) { - const auto& metadata = ctx_.client_metadata(); - const auto it = metadata.find(restricted_kv_.first); - return (it != metadata.end()) && (it->second == restricted_kv_.second); + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); + TRITONSERVER_ErrorDelete(err); + reactor->Finish(status); + return reactor; } - return true; -} - -template <typename ResponderType, typename RequestType, typename ResponseType> -void -CommonCallData<ResponderType, RequestType, ResponseType>::AddToCompletionQueue() -{ - alarm_.Set(cq_, gpr_now(gpr_clock_type::GPR_CLOCK_REALTIME), this); -} -template <typename ResponderType, typename RequestType, typename ResponseType> -void -CommonCallData<ResponderType, RequestType, ResponseType>::WriteResponse() -{ - if (response_delay_ != 0) { - // Will delay the write of the response by the specified time. - // This can be used to test the flow where there are other - // responses available to be written. - LOG_VERBOSE(1) << "Delaying the write of the response by " - << response_delay_ << " seconds"; - std::this_thread::sleep_for(std::chrono::seconds(response_delay_)); - } - step_ = Steps::COMPLETE; - responder_.Finish(response_, status_, this); -} + private: + std::shared_ptr<TRITONSERVER_Server> tritonserver_; + RestrictedFeatures restricted_keys_; +}; -// -// CommonHandler -// -// A common handler for all non-inference requests. -// -class CommonHandler : public HandlerBase { +class UnifiedCallbackService + : public inference::GRPCInferenceServiceCallback::CallbackService { public: - CommonHandler( + UnifiedCallbackService( const std::string& name, - const std::shared_ptr<TRITONSERVER_Server>& tritonserver, - const std::shared_ptr<SharedMemoryManager>& shm_manager, + const std::shared_ptr<TRITONSERVER_Server>& server, TraceManager* trace_manager, - inference::GRPCInferenceService::AsyncService* service, - ::grpc::health::v1::Health::AsyncService* health_service, - ::grpc::ServerCompletionQueue* cq, - const RestrictedFeatures& restricted_keys, const uint64_t response_delay); - - // Descriptive name of of the handler. - const std::string& Name() const { return name_; } - - // Start handling requests. - void Start() override; - - // Stop handling requests. 
- void Stop() override; + const std::shared_ptr<SharedMemoryManager>& shm_manager, + grpc_compression_level compression_level, + RestrictedFeatures& restricted_keys_, + const std::string& forward_header_pattern) + : tritonserver_(server), shm_manager_(shm_manager), + trace_manager_(trace_manager), restricted_keys_(restricted_keys_), + model_infer_handler_( + "ModelInferCallbackHandler", tritonserver_, trace_manager_, + shm_manager_, compression_level, restricted_keys_, + forward_header_pattern) + { + } - private: - void SetUpAllRequests(); - - // [FIXME] turn into generated code - void RegisterServerLive(); - void RegisterServerReady(); - void RegisterHealthCheck(); - void RegisterModelReady(); - void RegisterServerMetadata(); - void RegisterModelMetadata(); - void RegisterModelConfig(); - void RegisterModelStatistics(); - void RegisterTrace(); - void RegisterLogging(); - void RegisterSystemSharedMemoryStatus(); - void RegisterSystemSharedMemoryRegister(); - void RegisterSystemSharedMemoryUnregister(); - void RegisterCudaSharedMemoryStatus(); - void RegisterCudaSharedMemoryRegister(); - void RegisterCudaSharedMemoryUnregister(); - void RegisterRepositoryIndex(); - void RegisterRepositoryModelLoad(); - void RegisterRepositoryModelUnload(); - - // Set count and cumulative duration for 'RegisterModelStatistics()' template <typename PBTYPE> TRITONSERVER_Error* SetStatisticsDuration( triton::common::TritonJson::Value& statistics_json, const std::string& statistics_name, - PBTYPE* mutable_statistics_duration_protobuf) const; - - const std::string name_; - std::shared_ptr<TRITONSERVER_Server> tritonserver_; + PBTYPE* mutable_statistics_duration_protobuf) + { + triton::common::TritonJson::Value statistics_duration_json; + RETURN_IF_ERR(statistics_json.MemberAsObject( + statistics_name.c_str(), &statistics_duration_json)); + + uint64_t value; + RETURN_IF_ERR(statistics_duration_json.MemberAsUInt("count", &value)); + mutable_statistics_duration_protobuf->set_count(value); + RETURN_IF_ERR(statistics_duration_json.MemberAsUInt("ns", &value)); + mutable_statistics_duration_protobuf->set_ns(value); + return nullptr; + } - std::shared_ptr<SharedMemoryManager> shm_manager_; - TraceManager* trace_manager_; + ::grpc::ServerUnaryReactor* ModelInfer( + ::grpc::CallbackServerContext* context, + const inference::ModelInferRequest* request, + inference::ModelInferResponse* response) override + { + // No reactor is created here; in the callback API the member handler + // obtains the reactor from the server context and manages it for the + // lifetime of this RPC. - inference::GRPCInferenceService::AsyncService* service_; - ::grpc::health::v1::Health::AsyncService* health_service_; - ::grpc::ServerCompletionQueue* cq_; - std::unique_ptr<std::thread> thread_; - RestrictedFeatures restricted_keys_{}; - const uint64_t response_delay_ = 0; -}; + + // Delegate the request to the member handler, which starts the inference + // and returns the reactor for this RPC.
+ return model_infer_handler_.HandleModelInfer( + context, request, response); -CommonHandler::CommonHandler( - const std::string& name, - const std::shared_ptr<TRITONSERVER_Server>& tritonserver, - const std::shared_ptr<SharedMemoryManager>& shm_manager, - TraceManager* trace_manager, - inference::GRPCInferenceService::AsyncService* service, - ::grpc::health::v1::Health::AsyncService* health_service, - ::grpc::ServerCompletionQueue* cq, - const RestrictedFeatures& restricted_keys, - const uint64_t response_delay = 0) - : name_(name), tritonserver_(tritonserver), shm_manager_(shm_manager), - trace_manager_(trace_manager), service_(service), - health_service_(health_service), cq_(cq), - restricted_keys_(restricted_keys), response_delay_(response_delay) -{ -} + // The reactor returned by the handler is handed back to gRPC, which + // drives the remainder of this RPC to completion. + } -void -CommonHandler::Start() -{ - // Use a barrier to make sure we don't return until thread has - // started. - auto barrier = std::make_shared<Barrier>(2); - - thread_.reset(new std::thread([this, barrier] { - SetUpAllRequests(); - barrier->Wait(); - - void* tag; - bool ok; - - while (cq_->Next(&tag, &ok)) { - ICallData* call_data = static_cast<ICallData*>(tag); - if (!call_data->Process(ok)) { - LOG_VERBOSE(1) << "Done for " << call_data->Name() << ", " - << call_data->Id(); - delete call_data; + // ServerLive + ::grpc::ServerUnaryReactor* ServerLive( + ::grpc::CallbackServerContext* context, + const inference::ServerLiveRequest* request, + inference::ServerLiveResponse* response) override + { + auto* reactor = context->DefaultReactor(); + + // (Optionally) Check client metadata for restricted access. + const std::pair<std::string, std::string>& restricted_kv = + restricted_keys_.Get(RestrictedCategory::HEALTH); + if (!restricted_kv.first.empty()) { + const auto& metadata = context->client_metadata(); + auto it = metadata.find(restricted_kv.first); + if (it == metadata.end() || it->second != restricted_kv.second) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + "Missing or mismatched restricted header")); + return reactor; + } } - })); - - barrier->Wait(); - LOG_VERBOSE(1) << "Thread started for " << Name(); -} - -void -CommonHandler::Stop() -{ - if (thread_->joinable()) { - thread_->join(); - } - LOG_VERBOSE(1) << "Thread exited for " << Name(); -} - -void -CommonHandler::SetUpAllRequests() -{ - // Define all the RPCs to be handled by this handler below - // - // Within each of the Register function, the format of RPC specification is: - // 1. A OnRegister function: This will be called when the - // server is ready to receive the requests for this RPC. - // 2. A OnExecute function: This will be called when the - // to process the request. - // 3. Create a CommonCallData object with the above callback - // functions - - // health (GRPC standard) - RegisterHealthCheck(); - // health (Triton) - RegisterServerLive(); - RegisterServerReady(); - RegisterModelReady(); - - // Metadata - RegisterServerMetadata(); - RegisterModelMetadata(); - - // model config - RegisterModelConfig(); - - // shared memory - // system.. - RegisterSystemSharedMemoryStatus(); - RegisterSystemSharedMemoryRegister(); - RegisterSystemSharedMemoryUnregister(); - // cuda..
- RegisterCudaSharedMemoryStatus(); - RegisterCudaSharedMemoryRegister(); - RegisterCudaSharedMemoryUnregister(); - - // model repository - RegisterRepositoryIndex(); - RegisterRepositoryModelLoad(); - RegisterRepositoryModelUnload(); - - // statistics - RegisterModelStatistics(); - - // trace - RegisterTrace(); - - // logging - RegisterLogging(); -} - -void -CommonHandler::RegisterServerLive() -{ - auto OnRegisterServerLive = - [this]( - ::grpc::ServerContext* ctx, inference::ServerLiveRequest* request, - ::grpc::ServerAsyncResponseWriter<inference::ServerLiveResponse>* - responder, - void* tag) { - this->service_->RequestServerLive( - ctx, request, responder, this->cq_, this->cq_, tag); - }; - - auto OnExecuteServerLive = [this]( - inference::ServerLiveRequest& request, - inference::ServerLiveResponse* response, - ::grpc::Status* status) { + // Business logic for ServerLive. bool live = false; TRITONSERVER_Error* err = TRITONSERVER_ServerIsLive(tritonserver_.get(), &live); - response->set_live((err == nullptr) && live); - GrpcStatusUtil::Create(status, err); + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); TRITONSERVER_ErrorDelete(err); - }; - - const std::pair<std::string, std::string>& restricted_kv = - restricted_keys_.Get(RestrictedCategory::HEALTH); - new CommonCallData< - ::grpc::ServerAsyncResponseWriter<inference::ServerLiveResponse>, - inference::ServerLiveRequest, inference::ServerLiveResponse>( - "ServerLive", 0, OnRegisterServerLive, OnExecuteServerLive, - false /* async */, cq_, restricted_kv, response_delay_); -} + reactor->Finish(status); + return reactor; + } -void -CommonHandler::RegisterServerReady() -{ - auto OnRegisterServerReady = - [this]( - ::grpc::ServerContext* ctx, inference::ServerReadyRequest* request, - ::grpc::ServerAsyncResponseWriter<inference::ServerReadyResponse>* - responder, - void* tag) { - this->service_->RequestServerReady( - ctx, request, responder, this->cq_, this->cq_, tag); - }; - - auto OnExecuteServerReady = [this]( - inference::ServerReadyRequest& request, - inference::ServerReadyResponse* response, - ::grpc::Status* status) { + ::grpc::ServerUnaryReactor* ServerReady( + ::grpc::CallbackServerContext* context, + const inference::ServerReadyRequest* request, + inference::ServerReadyResponse* response) override + { + auto* reactor = context->DefaultReactor(); + + // (Optionally) Check client metadata for restricted access. + const std::pair<std::string, std::string>& restricted_kv = + restricted_keys_.Get(RestrictedCategory::HEALTH); + if (!restricted_kv.first.empty()) { + const auto& metadata = context->client_metadata(); + auto it = metadata.find(restricted_kv.first); + if (it == metadata.end() || it->second != restricted_kv.second) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + "Missing or mismatched restricted header")); + return reactor; + } + } + + // Business logic for ServerReady. 
bool ready = false; TRITONSERVER_Error* err = TRITONSERVER_ServerIsReady(tritonserver_.get(), &ready); - response->set_ready((err == nullptr) && ready); - GrpcStatusUtil::Create(status, err); + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); TRITONSERVER_ErrorDelete(err); - }; - - const std::pair<std::string, std::string>& restricted_kv = - restricted_keys_.Get(RestrictedCategory::HEALTH); - new CommonCallData< - ::grpc::ServerAsyncResponseWriter<inference::ServerReadyResponse>, - inference::ServerReadyRequest, inference::ServerReadyResponse>( - "ServerReady", 0, OnRegisterServerReady, OnExecuteServerReady, - false /* async */, cq_, restricted_kv, response_delay_); -} - -void -CommonHandler::RegisterHealthCheck() -{ - auto OnRegisterHealthCheck = - [this]( - ::grpc::ServerContext* ctx, - ::grpc::health::v1::HealthCheckRequest* request, - ::grpc::ServerAsyncResponseWriter< - ::grpc::health::v1::HealthCheckResponse>* responder, - void* tag) { - this->health_service_->RequestCheck( - ctx, request, responder, this->cq_, this->cq_, tag); - }; - - auto OnExecuteHealthCheck = [this]( - ::grpc::health::v1::HealthCheckRequest& - request, - ::grpc::health::v1::HealthCheckResponse* - response, - ::grpc::Status* status) { - bool live = false; - TRITONSERVER_Error* err = - TRITONSERVER_ServerIsReady(tritonserver_.get(), &live); + reactor->Finish(status); + return reactor; + } - auto serving_status = - ::grpc::health::v1::HealthCheckResponse_ServingStatus_UNKNOWN; - if (err == nullptr) { - serving_status = - live ? ::grpc::health::v1::HealthCheckResponse_ServingStatus_SERVING - : ::grpc::health::v1:: - HealthCheckResponse_ServingStatus_NOT_SERVING; + ::grpc::ServerUnaryReactor* ModelReady( + ::grpc::CallbackServerContext* context, + const inference::ModelReadyRequest* request, + inference::ModelReadyResponse* response) override + { + auto* reactor = context->DefaultReactor(); + + // (Optionally) Check client metadata for restricted access. + const std::pair<std::string, std::string>& restricted_kv = + restricted_keys_.Get(RestrictedCategory::HEALTH); + if (!restricted_kv.first.empty()) { + const auto& metadata = context->client_metadata(); + auto it = metadata.find(restricted_kv.first); + if (it == metadata.end() || it->second != restricted_kv.second) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + "Missing or mismatched restricted header")); + return reactor; + } } - response->set_status(serving_status); - GrpcStatusUtil::Create(status, err); - TRITONSERVER_ErrorDelete(err); - }; - - const std::pair<std::string, std::string>& restricted_kv = - restricted_keys_.Get(RestrictedCategory::HEALTH); - new CommonCallData< - ::grpc::ServerAsyncResponseWriter< - ::grpc::health::v1::HealthCheckResponse>, - ::grpc::health::v1::HealthCheckRequest, - ::grpc::health::v1::HealthCheckResponse>( - "Check", 0, OnRegisterHealthCheck, OnExecuteHealthCheck, - false /* async */, cq_, restricted_kv, response_delay_); -} - -void -CommonHandler::RegisterModelReady() -{ - auto OnRegisterModelReady = - [this]( - ::grpc::ServerContext* ctx, inference::ModelReadyRequest* request, - ::grpc::ServerAsyncResponseWriter<inference::ModelReadyResponse>* - responder, - void* tag) { - this->service_->RequestModelReady( - ctx, request, responder, this->cq_, this->cq_, tag); - }; - - auto OnExecuteModelReady = [this]( - inference::ModelReadyRequest& request, - inference::ModelReadyResponse* response, - ::grpc::Status* status) { + // Business logic for ModelReady. 
bool is_ready = false; int64_t requested_model_version; - auto err = - GetModelVersionFromString(request.version(), &requested_model_version); + TRITONSERVER_Error* err = + GetModelVersionFromString(request->version(), &requested_model_version); if (err == nullptr) { err = TRITONSERVER_ServerModelIsReady( - tritonserver_.get(), request.name().c_str(), requested_model_version, + tritonserver_.get(), request->name().c_str(), requested_model_version, &is_ready); } response->set_ready(is_ready); - GrpcStatusUtil::Create(status, err); + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); TRITONSERVER_ErrorDelete(err); - }; - - const std::pair<std::string, std::string>& restricted_kv = - restricted_keys_.Get(RestrictedCategory::HEALTH); - new CommonCallData< - ::grpc::ServerAsyncResponseWriter<inference::ModelReadyResponse>, - inference::ModelReadyRequest, inference::ModelReadyResponse>( - "ModelReady", 0, OnRegisterModelReady, OnExecuteModelReady, - false /* async */, cq_, restricted_kv, response_delay_); -} - -void -CommonHandler::RegisterServerMetadata() -{ - auto OnRegisterServerMetadata = - [this]( - ::grpc::ServerContext* ctx, inference::ServerMetadataRequest* request, - ::grpc::ServerAsyncResponseWriter<inference::ServerMetadataResponse>* - responder, - void* tag) { - this->service_->RequestServerMetadata( - ctx, request, responder, this->cq_, this->cq_, tag); - }; - - auto OnExecuteServerMetadata = - [this]( - inference::ServerMetadataRequest& request, - inference::ServerMetadataResponse* response, ::grpc::Status* status) { - TRITONSERVER_Message* server_metadata_message = nullptr; - TRITONSERVER_Error* err = TRITONSERVER_ServerMetadata( - tritonserver_.get(), &server_metadata_message); - GOTO_IF_ERR(err, earlyexit); - - const char* buffer; - size_t byte_size; - err = TRITONSERVER_MessageSerializeToJson( - server_metadata_message, &buffer, &byte_size); - GOTO_IF_ERR(err, earlyexit); + reactor->Finish(status); + return reactor; + } - { - triton::common::TritonJson::Value server_metadata_json; - err = server_metadata_json.Parse(buffer, byte_size); - GOTO_IF_ERR(err, earlyexit); + ::grpc::ServerUnaryReactor* ServerMetadata( + ::grpc::CallbackServerContext* context, + const inference::ServerMetadataRequest* request, + inference::ServerMetadataResponse* response) override + { + auto* reactor = context->DefaultReactor(); + + // (Optionally) Check client metadata for restricted access. + const std::pair<std::string, std::string>& restricted_kv = + restricted_keys_.Get(RestrictedCategory::METADATA); + if (!restricted_kv.first.empty()) { + const auto& metadata = context->client_metadata(); + auto it = metadata.find(restricted_kv.first); + if (it == metadata.end() || it->second != restricted_kv.second) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + "Missing or mismatched restricted header")); + return reactor; + } + } + // Business logic for ServerMetadata. 
+ TRITONSERVER_Message* server_metadata_message = nullptr; + TRITONSERVER_Error* err = TRITONSERVER_ServerMetadata( + tritonserver_.get(), &server_metadata_message); + if (err == nullptr) { + const char* buffer; + size_t byte_size; + err = TRITONSERVER_MessageSerializeToJson( + server_metadata_message, &buffer, &byte_size); + if (err == nullptr) { + triton::common::TritonJson::Value server_metadata_json; + err = server_metadata_json.Parse(buffer, byte_size); + if (err == nullptr) { const char* name; size_t namelen; err = server_metadata_json.MemberAsString("name", &name, &namelen); - GOTO_IF_ERR(err, earlyexit); - - const char* version; - size_t versionlen; - err = server_metadata_json.MemberAsString( - "version", &version, &versionlen); - GOTO_IF_ERR(err, earlyexit); - - response->set_name(std::string(name, namelen)); - response->set_version(std::string(version, versionlen)); - - if (server_metadata_json.Find("extensions")) { - triton::common::TritonJson::Value extensions_json; - err = server_metadata_json.MemberAsArray( - "extensions", &extensions_json); - GOTO_IF_ERR(err, earlyexit); - - for (size_t idx = 0; idx < extensions_json.ArraySize(); ++idx) { - const char* ext; - size_t extlen; - err = extensions_json.IndexAsString(idx, &ext, &extlen); - GOTO_IF_ERR(err, earlyexit); - response->add_extensions(std::string(ext, extlen)); + if (err == nullptr) { + const char* version; + size_t versionlen; + err = server_metadata_json.MemberAsString( + "version", &version, &versionlen); + if (err == nullptr) { + response->set_name(std::string(name, namelen)); + response->set_version(std::string(version, versionlen)); + + if (server_metadata_json.Find("extensions")) { + triton::common::TritonJson::Value extensions_json; + err = server_metadata_json.MemberAsArray( + "extensions", &extensions_json); + if (err == nullptr) { + for (size_t idx = 0; idx < extensions_json.ArraySize(); + ++idx) { + const char* ext; + size_t extlen; + err = extensions_json.IndexAsString(idx, &ext, &extlen); + if (err == nullptr) { + response->add_extensions(std::string(ext, extlen)); + } + } + } + } } } - TRITONSERVER_MessageDelete(server_metadata_message); } + } + TRITONSERVER_MessageDelete(server_metadata_message); + } - earlyexit: - GrpcStatusUtil::Create(status, err); - TRITONSERVER_ErrorDelete(err); - }; - - const std::pair<std::string, std::string>& restricted_kv = - restricted_keys_.Get(RestrictedCategory::METADATA); - new CommonCallData< - ::grpc::ServerAsyncResponseWriter<inference::ServerMetadataResponse>, - inference::ServerMetadataRequest, inference::ServerMetadataResponse>( - "ServerMetadata", 0, OnRegisterServerMetadata, OnExecuteServerMetadata, - false /* async */, cq_, restricted_kv, response_delay_); -} + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); + TRITONSERVER_ErrorDelete(err); + reactor->Finish(status); + return reactor; + } -void -CommonHandler::RegisterModelMetadata() -{ - auto OnRegisterModelMetadata = - [this]( - ::grpc::ServerContext* ctx, inference::ModelMetadataRequest* request, - ::grpc::ServerAsyncResponseWriter<inference::ModelMetadataResponse>* - responder, - void* tag) { - this->service_->RequestModelMetadata( - ctx, request, responder, this->cq_, this->cq_, tag); - }; - - auto OnExecuteModelMetadata = [this]( - inference::ModelMetadataRequest& request, - inference::ModelMetadataResponse* response, - ::grpc::Status* status) { + ::grpc::ServerUnaryReactor* ModelMetadata( + ::grpc::CallbackServerContext* context, + const inference::ModelMetadataRequest* request, + 
inference::ModelMetadataResponse* response) override + { + auto* reactor = context->DefaultReactor(); + + // (Optionally) Check client metadata for restricted access. + const std::pair<std::string, std::string>& restricted_kv = + restricted_keys_.Get(RestrictedCategory::METADATA); + if (!restricted_kv.first.empty()) { + const auto& metadata = context->client_metadata(); + auto it = metadata.find(restricted_kv.first); + if (it == metadata.end() || it->second != restricted_kv.second) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + "Missing or mismatched restricted header")); + return reactor; + } + } + + // Core business logic - kept same as original int64_t requested_model_version; auto err = - GetModelVersionFromString(request.version(), &requested_model_version); + GetModelVersionFromString(request->version(), &requested_model_version); GOTO_IF_ERR(err, earlyexit); - { TRITONSERVER_Message* model_metadata_message = nullptr; err = TRITONSERVER_ServerModelMetadata( - tritonserver_.get(), request.name().c_str(), requested_model_version, + tritonserver_.get(), request->name().c_str(), requested_model_version, &model_metadata_message); GOTO_IF_ERR(err, earlyexit); @@ -769,7 +482,6 @@ CommonHandler::RegisterModelMetadata() int64_t d; err = shape_json.IndexAsInt(sidx, &d); GOTO_IF_ERR(err, earlyexit); - io->add_shape(d); } } @@ -811,54 +523,51 @@ CommonHandler::RegisterModelMetadata() int64_t d; err = shape_json.IndexAsInt(sidx, &d); GOTO_IF_ERR(err, earlyexit); - io->add_shape(d); } } } } - TRITONSERVER_MessageDelete(model_metadata_message); } - earlyexit: - GrpcStatusUtil::Create(status, err); + + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); TRITONSERVER_ErrorDelete(err); - }; - - const std::pair<std::string, std::string>& restricted_kv = - restricted_keys_.Get(RestrictedCategory::METADATA); - new CommonCallData< - ::grpc::ServerAsyncResponseWriter<inference::ModelMetadataResponse>, - inference::ModelMetadataRequest, inference::ModelMetadataResponse>( - "ModelMetadata", 0, OnRegisterModelMetadata, OnExecuteModelMetadata, - false /* async */, cq_, restricted_kv, response_delay_); -} + reactor->Finish(status); + return reactor; + } -void -CommonHandler::RegisterModelConfig() -{ - auto OnRegisterModelConfig = - [this]( - ::grpc::ServerContext* ctx, inference::ModelConfigRequest* request, - ::grpc::ServerAsyncResponseWriter<inference::ModelConfigResponse>* - responder, - void* tag) { - this->service_->RequestModelConfig( - ctx, request, responder, this->cq_, this->cq_, tag); - }; - - auto OnExecuteModelConfig = [this]( - inference::ModelConfigRequest& request, - inference::ModelConfigResponse* response, - ::grpc::Status* status) { + ::grpc::ServerUnaryReactor* ModelConfig( + ::grpc::CallbackServerContext* context, + const inference::ModelConfigRequest* request, + inference::ModelConfigResponse* response) override + { + auto* reactor = context->DefaultReactor(); + + // (Optionally) Check client metadata for restricted access. 
+ const std::pair<std::string, std::string>& restricted_kv = + restricted_keys_.Get(RestrictedCategory::MODEL_CONFIG); + if (!restricted_kv.first.empty()) { + const auto& metadata = context->client_metadata(); + auto it = metadata.find(restricted_kv.first); + if (it == metadata.end() || it->second != restricted_kv.second) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + "Missing or mismatched restricted header")); + return reactor; + } + } + + // Core business logic int64_t requested_model_version; auto err = - GetModelVersionFromString(request.version(), &requested_model_version); + GetModelVersionFromString(request->version(), &requested_model_version); if (err == nullptr) { TRITONSERVER_Message* model_config_message = nullptr; err = TRITONSERVER_ServerModelConfig( - tritonserver_.get(), request.name().c_str(), requested_model_version, + tritonserver_.get(), request->name().c_str(), requested_model_version, 1 /* config_version */, &model_config_message); if (err == nullptr) { const char* buffer; @@ -875,51 +584,48 @@ CommonHandler::RegisterModelConfig() } } - GrpcStatusUtil::Create(status, err); + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); TRITONSERVER_ErrorDelete(err); - }; - - const std::pair<std::string, std::string>& restricted_kv = - restricted_keys_.Get(RestrictedCategory::MODEL_CONFIG); - new CommonCallData< - ::grpc::ServerAsyncResponseWriter<inference::ModelConfigResponse>, - inference::ModelConfigRequest, inference::ModelConfigResponse>( - "ModelConfig", 0, OnRegisterModelConfig, OnExecuteModelConfig, - false /* async */, cq_, restricted_kv, response_delay_); -} + reactor->Finish(status); + return reactor; + } -void -CommonHandler::RegisterModelStatistics() -{ - auto OnRegisterModelStatistics = - [this]( - ::grpc::ServerContext* ctx, - inference::ModelStatisticsRequest* request, - ::grpc::ServerAsyncResponseWriter<inference::ModelStatisticsResponse>* - responder, - void* tag) { - this->service_->RequestModelStatistics( - ctx, request, responder, this->cq_, this->cq_, tag); - }; - - auto OnExecuteModelStatistics = [this]( - inference::ModelStatisticsRequest& - request, - inference::ModelStatisticsResponse* - response, - ::grpc::Status* status) { + // Other RPC methods (e.g., ServerReady, HealthCheck) would be implemented + // similarly. + ::grpc::ServerUnaryReactor* ModelStatistics( + ::grpc::CallbackServerContext* context, + const inference::ModelStatisticsRequest* request, + inference::ModelStatisticsResponse* response) override + { + auto* reactor = context->DefaultReactor(); + + // (Optionally) Check client metadata for restricted access. 
+ const std::pair<std::string, std::string>& restricted_kv = + restricted_keys_.Get(RestrictedCategory::STATISTICS); + if (!restricted_kv.first.empty()) { + const auto& metadata = context->client_metadata(); + auto it = metadata.find(restricted_kv.first); + if (it == metadata.end() || it->second != restricted_kv.second) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + "Missing or mismatched restricted header")); + return reactor; + } + } + + // Core business logic - kept same as original #ifdef TRITON_ENABLE_STATS triton::common::TritonJson::Value model_stats_json; int64_t requested_model_version; auto err = - GetModelVersionFromString(request.version(), &requested_model_version); + GetModelVersionFromString(request->version(), &requested_model_version); GOTO_IF_ERR(err, earlyexit); - { TRITONSERVER_Message* model_stats_message = nullptr; err = TRITONSERVER_ServerModelStatistics( - tritonserver_.get(), request.name().c_str(), requested_model_version, + tritonserver_.get(), request->name().c_str(), requested_model_version, &model_stats_message); GOTO_IF_ERR(err, earlyexit); @@ -1118,63 +824,44 @@ CommonHandler::RegisterModelStatistics() } earlyexit: - GrpcStatusUtil::Create(status, err); + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); TRITONSERVER_ErrorDelete(err); #else auto err = TRITONSERVER_ErrorNew( TRITONSERVER_ERROR_UNAVAILABLE, "the server does not support model statistics"); - GrpcStatusUtil::Create(status, err); + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); TRITONSERVER_ErrorDelete(err); #endif - }; - - const std::pair<std::string, std::string>& restricted_kv = - restricted_keys_.Get(RestrictedCategory::STATISTICS); - new CommonCallData< - ::grpc::ServerAsyncResponseWriter<inference::ModelStatisticsResponse>, - inference::ModelStatisticsRequest, inference::ModelStatisticsResponse>( - "ModelStatistics", 0, OnRegisterModelStatistics, OnExecuteModelStatistics, - false /* async */, cq_, restricted_kv, response_delay_); -} - -template <typename PBTYPE> -TRITONSERVER_Error* -CommonHandler::SetStatisticsDuration( - triton::common::TritonJson::Value& statistics_json, - const std::string& statistics_name, - PBTYPE* mutable_statistics_duration_protobuf) const -{ - triton::common::TritonJson::Value statistics_duration_json; - RETURN_IF_ERR(statistics_json.MemberAsObject( - statistics_name.c_str(), &statistics_duration_json)); - uint64_t value; - RETURN_IF_ERR(statistics_duration_json.MemberAsUInt("count", &value)); - mutable_statistics_duration_protobuf->set_count(value); - RETURN_IF_ERR(statistics_duration_json.MemberAsUInt("ns", &value)); - mutable_statistics_duration_protobuf->set_ns(value); + reactor->Finish(status); + return reactor; + } - return nullptr; -} + ::grpc::ServerUnaryReactor* TraceSetting( + ::grpc::CallbackServerContext* context, + const inference::TraceSettingRequest* request, + inference::TraceSettingResponse* response) override + { + auto* reactor = context->DefaultReactor(); + + // (Optionally) Check client metadata for restricted access. 
+ const std::pair<std::string, std::string>& restricted_kv = + restricted_keys_.Get(RestrictedCategory::TRACE); + if (!restricted_kv.first.empty()) { + const auto& metadata = context->client_metadata(); + auto it = metadata.find(restricted_kv.first); + if (it == metadata.end() || it->second != restricted_kv.second) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + "Missing or mismatched restricted header")); + return reactor; + } + } -void -CommonHandler::RegisterTrace() -{ - auto OnRegisterTrace = - [this]( - ::grpc::ServerContext* ctx, inference::TraceSettingRequest* request, - ::grpc::ServerAsyncResponseWriter<inference::TraceSettingResponse>* - responder, - void* tag) { - this->service_->RequestTraceSetting( - ctx, request, responder, this->cq_, this->cq_, tag); - }; - - auto OnExecuteTrace = [this]( - inference::TraceSettingRequest& request, - inference::TraceSettingResponse* response, - ::grpc::Status* status) { + // Core business logic - kept same as original #ifdef TRITON_ENABLE_TRACING TRITONSERVER_Error* err = nullptr; TRITONSERVER_InferenceTraceLevel level = TRITONSERVER_TRACE_LEVEL_DISABLED; @@ -1185,29 +872,29 @@ CommonHandler::RegisterTrace() InferenceTraceMode trace_mode; TraceConfigMap config_map; - if (!request.model_name().empty()) { + if (!request->model_name().empty()) { bool ready = false; - GOTO_IF_ERR( - TRITONSERVER_ServerModelIsReady( - tritonserver_.get(), request.model_name().c_str(), - -1 /* model version */, &ready), - earlyexit); + err = TRITONSERVER_ServerModelIsReady( + tritonserver_.get(), request->model_name().c_str(), + -1 /* model version */, &ready); + GOTO_IF_ERR(err, earlyexit); if (!ready) { err = TRITONSERVER_ErrorNew( TRITONSERVER_ERROR_INVALID_ARG, - (std::string("Request for unknown model : ") + request.model_name()) + (std::string("Request for unknown model : ") + + request->model_name()) .c_str()); GOTO_IF_ERR(err, earlyexit); } } // Update trace setting - if (!request.settings().empty()) { + if (!request->settings().empty()) { TraceManager::NewSetting new_setting; { static std::string setting_name = "trace_file"; - auto it = request.settings().find(setting_name); - if (it != request.settings().end()) { + auto it = request->settings().find(setting_name); + if (it != request->settings().end()) { err = TRITONSERVER_ErrorNew( TRITONSERVER_ERROR_UNSUPPORTED, "trace file location can not be updated through network " @@ -1217,8 +904,8 @@ CommonHandler::RegisterTrace() } { static std::string setting_name = "trace_level"; - auto it = request.settings().find(setting_name); - if (it != request.settings().end()) { + auto it = request->settings().find(setting_name); + if (it != request->settings().end()) { if (it->second.value().size() == 0) { new_setting.clear_level_ = true; } else { @@ -1248,8 +935,8 @@ CommonHandler::RegisterTrace() } { static std::string setting_name = "trace_rate"; - auto it = request.settings().find(setting_name); - if (it != request.settings().end()) { + auto it = request->settings().find(setting_name); + if (it != request->settings().end()) { if (it->second.value().size() == 0) { new_setting.clear_rate_ = true; } else if (it->second.value().size() == 1) { @@ -1288,8 +975,8 @@ CommonHandler::RegisterTrace() } { static std::string setting_name = "trace_count"; - auto it = request.settings().find(setting_name); - if (it != request.settings().end()) { + auto it = request->settings().find(setting_name); + if (it != request->settings().end()) { if (it->second.value().size() == 0) { new_setting.clear_count_ = true; 
} else if (it->second.value().size() == 1) { @@ -1337,8 +1024,8 @@ CommonHandler::RegisterTrace() } { static std::string setting_name = "log_frequency"; - auto it = request.settings().find(setting_name); - if (it != request.settings().end()) { + auto it = request->settings().find(setting_name); + if (it != request->settings().end()) { if (it->second.value().size() == 0) { new_setting.clear_log_frequency_ = true; } else if (it->second.value().size() == 1) { @@ -1376,16 +1063,16 @@ CommonHandler::RegisterTrace() } } - err = - trace_manager_->UpdateTraceSetting(request.model_name(), new_setting); + err = trace_manager_->UpdateTraceSetting( + request->model_name(), new_setting); GOTO_IF_ERR(err, earlyexit); } - // Get current trace setting, this is needed even if the setting - // has been updated above as some values may not be provided in the request. + // Get current trace setting trace_manager_->GetTraceSetting( - request.model_name(), &level, &rate, &count, &log_frequency, &filepath, + request->model_name(), &level, &rate, &count, &log_frequency, &filepath, &trace_mode, &config_map); + // level { inference::TraceSettingResponse::SettingValue level_setting; @@ -1401,6 +1088,7 @@ CommonHandler::RegisterTrace() } (*response->mutable_settings())["trace_level"] = level_setting; } + (*response->mutable_settings())["trace_rate"].add_value( std::to_string(rate)); (*response->mutable_settings())["trace_count"].add_value( @@ -1432,54 +1120,55 @@ CommonHandler::RegisterTrace() } } } + earlyexit: - GrpcStatusUtil::Create(status, err); + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); TRITONSERVER_ErrorDelete(err); #else auto err = TRITONSERVER_ErrorNew( TRITONSERVER_ERROR_UNAVAILABLE, "the server does not support trace"); - GrpcStatusUtil::Create(status, err); + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); TRITONSERVER_ErrorDelete(err); #endif - }; - - const std::pair<std::string, std::string>& restricted_kv = - restricted_keys_.Get(RestrictedCategory::TRACE); - new CommonCallData< - ::grpc::ServerAsyncResponseWriter<inference::TraceSettingResponse>, - inference::TraceSettingRequest, inference::TraceSettingResponse>( - "Trace", 0, OnRegisterTrace, OnExecuteTrace, false /* async */, cq_, - restricted_kv, response_delay_); -} -void -CommonHandler::RegisterLogging() -{ - auto OnRegisterLogging = - [this]( - ::grpc::ServerContext* ctx, inference::LogSettingsRequest* request, - ::grpc::ServerAsyncResponseWriter<inference::LogSettingsResponse>* - responder, - void* tag) { - this->service_->RequestLogSettings( - ctx, request, responder, this->cq_, this->cq_, tag); - }; - - auto OnExecuteLogging = [this]( - inference::LogSettingsRequest& request, - inference::LogSettingsResponse* response, - ::grpc::Status* status) { + reactor->Finish(status); + return reactor; + } + + ::grpc::ServerUnaryReactor* LogSettings( + ::grpc::CallbackServerContext* context, + const inference::LogSettingsRequest* request, + inference::LogSettingsResponse* response) override + { + auto* reactor = context->DefaultReactor(); + + // (Optionally) Check client metadata for restricted access. 
+ const std::pair<std::string, std::string>& restricted_kv = + restricted_keys_.Get(RestrictedCategory::LOGGING); + if (!restricted_kv.first.empty()) { + const auto& metadata = context->client_metadata(); + auto it = metadata.find(restricted_kv.first); + if (it == metadata.end() || it->second != restricted_kv.second) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + "Missing or mismatched restricted header")); + return reactor; + } + } + // Core business logic - kept same as original #ifdef TRITON_ENABLE_LOGGING TRITONSERVER_Error* err = nullptr; // Update log settings // Server and Core repos do not have the same Logger object // Each update must be applied to both server and core repo versions - if (!request.settings().empty()) { + if (!request->settings().empty()) { { static std::string setting_name = "log_file"; - auto it = request.settings().find(setting_name); - if (it != request.settings().end()) { + auto it = request->settings().find(setting_name); + if (it != request->settings().end()) { err = TRITONSERVER_ErrorNew( TRITONSERVER_ERROR_UNSUPPORTED, "log file location can not be updated through network protocol"); @@ -1488,8 +1177,8 @@ CommonHandler::RegisterLogging() } { static std::string setting_name = "log_info"; - auto it = request.settings().find(setting_name); - if (it != request.settings().end()) { + auto it = request->settings().find(setting_name); + if (it != request->settings().end()) { const auto& log_param = it->second; if (log_param.parameter_choice_case() != inference::LogSettingsRequest_SettingValue::ParameterChoiceCase:: @@ -1508,8 +1197,8 @@ CommonHandler::RegisterLogging() } { static std::string setting_name = "log_warning"; - auto it = request.settings().find(setting_name); - if (it != request.settings().end()) { + auto it = request->settings().find(setting_name); + if (it != request->settings().end()) { const auto& log_param = it->second; if (log_param.parameter_choice_case() != inference::LogSettingsRequest_SettingValue::ParameterChoiceCase:: @@ -1528,8 +1217,8 @@ CommonHandler::RegisterLogging() } { static std::string setting_name = "log_error"; - auto it = request.settings().find(setting_name); - if (it != request.settings().end()) { + auto it = request->settings().find(setting_name); + if (it != request->settings().end()) { const auto& log_param = it->second; if (log_param.parameter_choice_case() != inference::LogSettingsRequest_SettingValue::ParameterChoiceCase:: @@ -1548,8 +1237,8 @@ CommonHandler::RegisterLogging() } { static std::string setting_name = "log_verbose_level"; - auto it = request.settings().find(setting_name); - if (it != request.settings().end()) { + auto it = request->settings().find(setting_name); + if (it != request->settings().end()) { const auto& log_param = it->second; if (log_param.parameter_choice_case() != inference::LogSettingsRequest_SettingValue::ParameterChoiceCase:: @@ -1568,8 +1257,8 @@ CommonHandler::RegisterLogging() } { static std::string setting_name = "log_format"; - auto it = request.settings().find(setting_name); - if (it != request.settings().end()) { + auto it = request->settings().find(setting_name); + if (it != request->settings().end()) { const auto& log_param = it->second; if (log_param.parameter_choice_case() != inference::LogSettingsRequest_SettingValue::ParameterChoiceCase:: @@ -1608,6 +1297,7 @@ CommonHandler::RegisterLogging() } GOTO_IF_ERR(err, earlyexit); } + (*response->mutable_settings())["log_file"].set_string_param(LOG_FILE); 
(*response->mutable_settings())["log_info"].set_bool_param(LOG_INFO_IS_ON); (*response->mutable_settings())["log_warning"].set_bool_param( @@ -1618,628 +1308,688 @@ CommonHandler::RegisterLogging() LOG_VERBOSE_LEVEL); (*response->mutable_settings())["log_format"].set_string_param( LOG_FORMAT_STRING); + earlyexit: - GrpcStatusUtil::Create(status, err); + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); TRITONSERVER_ErrorDelete(err); #else auto err = TRITONSERVER_ErrorNew( TRITONSERVER_ERROR_UNAVAILABLE, "the server does not support dynamic logging"); - GrpcStatusUtil::Create(status, err); + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); TRITONSERVER_ErrorDelete(err); #endif - }; - - const std::pair<std::string, std::string>& restricted_kv = - restricted_keys_.Get(RestrictedCategory::LOGGING); - new CommonCallData< - ::grpc::ServerAsyncResponseWriter<inference::LogSettingsResponse>, - inference::LogSettingsRequest, inference::LogSettingsResponse>( - "Logging", 0, OnRegisterLogging, OnExecuteLogging, false /* async */, cq_, - restricted_kv, response_delay_); -} -void -CommonHandler::RegisterSystemSharedMemoryStatus() -{ - auto OnRegisterSystemSharedMemoryStatus = - [this]( - ::grpc::ServerContext* ctx, - inference::SystemSharedMemoryStatusRequest* request, - ::grpc::ServerAsyncResponseWriter< - inference::SystemSharedMemoryStatusResponse>* responder, - void* tag) { - this->service_->RequestSystemSharedMemoryStatus( - ctx, request, responder, this->cq_, this->cq_, tag); - }; - - auto OnExecuteSystemSharedMemoryStatus = - [this]( - inference::SystemSharedMemoryStatusRequest& request, - inference::SystemSharedMemoryStatusResponse* response, - ::grpc::Status* status) { - triton::common::TritonJson::Value shm_status_json( - triton::common::TritonJson::ValueType::ARRAY); - TRITONSERVER_Error* err = shm_manager_->GetStatus( - request.name(), TRITONSERVER_MEMORY_CPU, &shm_status_json); - GOTO_IF_ERR(err, earlyexit); + reactor->Finish(status); + return reactor; + } - for (size_t idx = 0; idx < shm_status_json.ArraySize(); ++idx) { - triton::common::TritonJson::Value shm_region_json; - err = shm_status_json.IndexAsObject(idx, &shm_region_json); - GOTO_IF_ERR(err, earlyexit); + ::grpc::ServerUnaryReactor* SystemSharedMemoryRegister( + ::grpc::CallbackServerContext* context, + const inference::SystemSharedMemoryRegisterRequest* request, + inference::SystemSharedMemoryRegisterResponse* response) override + { + auto* reactor = context->DefaultReactor(); + + // (Optionally) Check client metadata for restricted access. 
+ const std::pair<std::string, std::string>& restricted_kv = + restricted_keys_.Get(RestrictedCategory::SHARED_MEMORY); + if (!restricted_kv.first.empty()) { + const auto& metadata = context->client_metadata(); + auto it = metadata.find(restricted_kv.first); + if (it == metadata.end() || it->second != restricted_kv.second) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + "Missing or mismatched restricted header")); + return reactor; + } + } - const char* name; - size_t namelen; - err = shm_region_json.MemberAsString("name", &name, &namelen); - GOTO_IF_ERR(err, earlyexit); + // Core business logic - kept same as original + TRITONSERVER_Error* err = shm_manager_->RegisterSystemSharedMemory( + request->name(), request->key(), request->offset(), + request->byte_size()); - const char* key; - size_t keylen; - err = shm_region_json.MemberAsString("key", &key, &keylen); - GOTO_IF_ERR(err, earlyexit); + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); + TRITONSERVER_ErrorDelete(err); + reactor->Finish(status); + return reactor; + } - uint64_t offset; - err = shm_region_json.MemberAsUInt("offset", &offset); - GOTO_IF_ERR(err, earlyexit); + ::grpc::ServerUnaryReactor* SystemSharedMemoryStatus( + ::grpc::CallbackServerContext* context, + const inference::SystemSharedMemoryStatusRequest* request, + inference::SystemSharedMemoryStatusResponse* response) override + { + auto* reactor = context->DefaultReactor(); + + // (Optionally) Check client metadata for restricted access. + const std::pair<std::string, std::string>& restricted_kv = + restricted_keys_.Get(RestrictedCategory::SHARED_MEMORY); + if (!restricted_kv.first.empty()) { + const auto& metadata = context->client_metadata(); + auto it = metadata.find(restricted_kv.first); + if (it == metadata.end() || it->second != restricted_kv.second) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + "Missing or mismatched restricted header")); + return reactor; + } + } - uint64_t byte_size; - err = shm_region_json.MemberAsUInt("byte_size", &byte_size); - GOTO_IF_ERR(err, earlyexit); + // Core business logic - kept same as original + triton::common::TritonJson::Value shm_status_json( + triton::common::TritonJson::ValueType::ARRAY); + TRITONSERVER_Error* err = shm_manager_->GetStatus( + request->name(), TRITONSERVER_MEMORY_CPU, &shm_status_json); + GOTO_IF_ERR(err, earlyexit); - inference::SystemSharedMemoryStatusResponse::RegionStatus - region_status; - region_status.set_name(std::string(name, namelen)); - region_status.set_key(std::string(key, keylen)); - region_status.set_offset(offset); - region_status.set_byte_size(byte_size); + for (size_t idx = 0; idx < shm_status_json.ArraySize(); ++idx) { + triton::common::TritonJson::Value shm_region_json; + err = shm_status_json.IndexAsObject(idx, &shm_region_json); + GOTO_IF_ERR(err, earlyexit); - (*response->mutable_regions())[name] = region_status; - } + const char* name; + size_t namelen; + err = shm_region_json.MemberAsString("name", &name, &namelen); + GOTO_IF_ERR(err, earlyexit); - earlyexit: - GrpcStatusUtil::Create(status, err); - TRITONSERVER_ErrorDelete(err); - }; - - const std::pair<std::string, std::string>& restricted_kv = - restricted_keys_.Get(RestrictedCategory::SHARED_MEMORY); - new CommonCallData< - ::grpc::ServerAsyncResponseWriter< - inference::SystemSharedMemoryStatusResponse>, - inference::SystemSharedMemoryStatusRequest, - inference::SystemSharedMemoryStatusResponse>( - "SystemSharedMemoryStatus", 0, OnRegisterSystemSharedMemoryStatus, 
- OnExecuteSystemSharedMemoryStatus, false /* async */, cq_, restricted_kv, - response_delay_); -} + const char* key; + size_t keylen; + err = shm_region_json.MemberAsString("key", &key, &keylen); + GOTO_IF_ERR(err, earlyexit); -void -CommonHandler::RegisterSystemSharedMemoryRegister() -{ - auto OnRegisterSystemSharedMemoryRegister = - [this]( - ::grpc::ServerContext* ctx, - inference::SystemSharedMemoryRegisterRequest* request, - ::grpc::ServerAsyncResponseWriter< - inference::SystemSharedMemoryRegisterResponse>* responder, - void* tag) { - this->service_->RequestSystemSharedMemoryRegister( - ctx, request, responder, this->cq_, this->cq_, tag); - }; - - auto OnExecuteSystemSharedMemoryRegister = - [this]( - inference::SystemSharedMemoryRegisterRequest& request, - inference::SystemSharedMemoryRegisterResponse* response, - ::grpc::Status* status) { - TRITONSERVER_Error* err = shm_manager_->RegisterSystemSharedMemory( - request.name(), request.key(), request.offset(), - request.byte_size()); - - GrpcStatusUtil::Create(status, err); - TRITONSERVER_ErrorDelete(err); - }; - - const std::pair<std::string, std::string>& restricted_kv = - restricted_keys_.Get(RestrictedCategory::SHARED_MEMORY); - new CommonCallData< - ::grpc::ServerAsyncResponseWriter< - inference::SystemSharedMemoryRegisterResponse>, - inference::SystemSharedMemoryRegisterRequest, - inference::SystemSharedMemoryRegisterResponse>( - "SystemSharedMemoryRegister", 0, OnRegisterSystemSharedMemoryRegister, - OnExecuteSystemSharedMemoryRegister, false /* async */, cq_, - restricted_kv, response_delay_); -} + uint64_t offset; + err = shm_region_json.MemberAsUInt("offset", &offset); + GOTO_IF_ERR(err, earlyexit); -void -CommonHandler::RegisterSystemSharedMemoryUnregister() -{ - auto OnRegisterSystemSharedMemoryUnregister = - [this]( - ::grpc::ServerContext* ctx, - inference::SystemSharedMemoryUnregisterRequest* request, - ::grpc::ServerAsyncResponseWriter< - inference::SystemSharedMemoryUnregisterResponse>* responder, - void* tag) { - this->service_->RequestSystemSharedMemoryUnregister( - ctx, request, responder, this->cq_, this->cq_, tag); - }; - - auto OnExecuteSystemSharedMemoryUnregister = - [this]( - inference::SystemSharedMemoryUnregisterRequest& request, - inference::SystemSharedMemoryUnregisterResponse* response, - ::grpc::Status* status) { - TRITONSERVER_Error* err = nullptr; - if (request.name().empty()) { - err = shm_manager_->UnregisterAll(TRITONSERVER_MEMORY_CPU); - } else { - err = - shm_manager_->Unregister(request.name(), TRITONSERVER_MEMORY_CPU); - } + uint64_t byte_size; + err = shm_region_json.MemberAsUInt("byte_size", &byte_size); + GOTO_IF_ERR(err, earlyexit); - GrpcStatusUtil::Create(status, err); - TRITONSERVER_ErrorDelete(err); - }; - - const std::pair<std::string, std::string>& restricted_kv = - restricted_keys_.Get(RestrictedCategory::SHARED_MEMORY); - new CommonCallData< - ::grpc::ServerAsyncResponseWriter< - inference::SystemSharedMemoryUnregisterResponse>, - inference::SystemSharedMemoryUnregisterRequest, - inference::SystemSharedMemoryUnregisterResponse>( - "SystemSharedMemoryUnregister", 0, OnRegisterSystemSharedMemoryUnregister, - OnExecuteSystemSharedMemoryUnregister, false /* async */, cq_, - restricted_kv, response_delay_); -} + inference::SystemSharedMemoryStatusResponse::RegionStatus region_status; + region_status.set_name(std::string(name, namelen)); + region_status.set_key(std::string(key, keylen)); + region_status.set_offset(offset); + region_status.set_byte_size(byte_size); -void 
-CommonHandler::RegisterCudaSharedMemoryStatus() -{ - auto OnRegisterCudaSharedMemoryStatus = - [this]( - ::grpc::ServerContext* ctx, - inference::CudaSharedMemoryStatusRequest* request, - ::grpc::ServerAsyncResponseWriter< - inference::CudaSharedMemoryStatusResponse>* responder, - void* tag) { - this->service_->RequestCudaSharedMemoryStatus( - ctx, request, responder, this->cq_, this->cq_, tag); - }; - auto OnExecuteCudaSharedMemoryStatus = - [this]( - inference::CudaSharedMemoryStatusRequest& request, - inference::CudaSharedMemoryStatusResponse* response, - ::grpc::Status* status) { - triton::common::TritonJson::Value shm_status_json( - triton::common::TritonJson::ValueType::ARRAY); - TRITONSERVER_Error* err = shm_manager_->GetStatus( - request.name(), TRITONSERVER_MEMORY_GPU, &shm_status_json); - GOTO_IF_ERR(err, earlyexit); - - for (size_t idx = 0; idx < shm_status_json.ArraySize(); ++idx) { - triton::common::TritonJson::Value shm_region_json; - err = shm_status_json.IndexAsObject(idx, &shm_region_json); - GOTO_IF_ERR(err, earlyexit); + (*response->mutable_regions())[name] = region_status; + } - const char* name; - size_t namelen; - err = shm_region_json.MemberAsString("name", &name, &namelen); - GOTO_IF_ERR(err, earlyexit); + earlyexit: + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); + TRITONSERVER_ErrorDelete(err); + reactor->Finish(status); + return reactor; + } - uint64_t device_id; - err = shm_region_json.MemberAsUInt("device_id", &device_id); - GOTO_IF_ERR(err, earlyexit); + ::grpc::ServerUnaryReactor* CudaSharedMemoryRegister( + ::grpc::CallbackServerContext* context, + const inference::CudaSharedMemoryRegisterRequest* request, + inference::CudaSharedMemoryRegisterResponse* response) override + { + auto* reactor = context->DefaultReactor(); + + // (Optionally) Check client metadata for restricted access. 
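+    // The restricted key/value pair is configured on the server side (for
+    // example through Triton's restricted-protocol options); when it is set,
+    // a client must attach the same pair as gRPC metadata. A minimal
+    // client-side sketch with a purely illustrative key/value and a generic
+    // generated stub, shown here only for orientation:
+    //
+    //   ::grpc::ClientContext ctx;
+    //   ctx.AddMetadata("admin-key", "admin-value");  // must match server config
+    //   stub->CudaSharedMemoryRegister(&ctx, request, &response);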
+ const std::pair<std::string, std::string>& restricted_kv = + restricted_keys_.Get(RestrictedCategory::SHARED_MEMORY); + if (!restricted_kv.first.empty()) { + const auto& metadata = context->client_metadata(); + auto it = metadata.find(restricted_kv.first); + if (it == metadata.end() || it->second != restricted_kv.second) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + "Missing or mismatched restricted header")); + return reactor; + } + } - uint64_t byte_size; - err = shm_region_json.MemberAsUInt("byte_size", &byte_size); - GOTO_IF_ERR(err, earlyexit); + // Core business logic + TRITONSERVER_Error* err = nullptr; +#ifdef TRITON_ENABLE_GPU + err = shm_manager_->RegisterCUDASharedMemory( + request->name(), + reinterpret_cast<const cudaIpcMemHandle_t*>( + request->raw_handle().c_str()), + request->byte_size(), request->device_id()); +#else + err = TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + std::string( + "failed to register CUDA shared memory region: '" + + request->name() + "', GPUs not supported") + .c_str()); +#endif // TRITON_ENABLE_GPU + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); + TRITONSERVER_ErrorDelete(err); + reactor->Finish(status); + return reactor; + } - inference::CudaSharedMemoryStatusResponse::RegionStatus region_status; - region_status.set_name(std::string(name, namelen)); - region_status.set_device_id(device_id); - region_status.set_byte_size(byte_size); + ::grpc::ServerUnaryReactor* CudaSharedMemoryStatus( + ::grpc::CallbackServerContext* context, + const inference::CudaSharedMemoryStatusRequest* request, + inference::CudaSharedMemoryStatusResponse* response) override + { + auto* reactor = context->DefaultReactor(); + + // (Optionally) Check client metadata for restricted access. 
+ const std::pair<std::string, std::string>& restricted_kv = + restricted_keys_.Get(RestrictedCategory::SHARED_MEMORY); + if (!restricted_kv.first.empty()) { + const auto& metadata = context->client_metadata(); + auto it = metadata.find(restricted_kv.first); + if (it == metadata.end() || it->second != restricted_kv.second) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + "Missing or mismatched restricted header")); + return reactor; + } + } - (*response->mutable_regions())[name] = region_status; - } - earlyexit: - GrpcStatusUtil::Create(status, err); - TRITONSERVER_ErrorDelete(err); - }; - - const std::pair<std::string, std::string>& restricted_kv = - restricted_keys_.Get(RestrictedCategory::SHARED_MEMORY); - new CommonCallData< - ::grpc::ServerAsyncResponseWriter< - inference::CudaSharedMemoryStatusResponse>, - inference::CudaSharedMemoryStatusRequest, - inference::CudaSharedMemoryStatusResponse>( - "CudaSharedMemoryStatus", 0, OnRegisterCudaSharedMemoryStatus, - OnExecuteCudaSharedMemoryStatus, false /* async */, cq_, restricted_kv, - response_delay_); -} + // Core business logic - kept same as original + triton::common::TritonJson::Value shm_status_json( + triton::common::TritonJson::ValueType::ARRAY); + TRITONSERVER_Error* err = shm_manager_->GetStatus( + request->name(), TRITONSERVER_MEMORY_GPU, &shm_status_json); + GOTO_IF_ERR(err, earlyexit); -void -CommonHandler::RegisterCudaSharedMemoryRegister() -{ - auto OnRegisterCudaSharedMemoryRegister = - [this]( - ::grpc::ServerContext* ctx, - inference::CudaSharedMemoryRegisterRequest* request, - ::grpc::ServerAsyncResponseWriter< - inference::CudaSharedMemoryRegisterResponse>* responder, - void* tag) { - this->service_->RequestCudaSharedMemoryRegister( - ctx, request, responder, this->cq_, this->cq_, tag); - }; - - auto OnExecuteCudaSharedMemoryRegister = - [this]( - inference::CudaSharedMemoryRegisterRequest& request, - inference::CudaSharedMemoryRegisterResponse* response, - ::grpc::Status* status) { - TRITONSERVER_Error* err = nullptr; -#ifdef TRITON_ENABLE_GPU - err = shm_manager_->RegisterCUDASharedMemory( - request.name(), - reinterpret_cast<const cudaIpcMemHandle_t*>( - request.raw_handle().c_str()), - request.byte_size(), request.device_id()); -#else - err = TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INVALID_ARG, - std::string( - "failed to register CUDA shared memory region: '" + - request.name() + "', GPUs not supported") - .c_str()); -#endif // TRITON_ENABLE_GPU + for (size_t idx = 0; idx < shm_status_json.ArraySize(); ++idx) { + triton::common::TritonJson::Value shm_region_json; + err = shm_status_json.IndexAsObject(idx, &shm_region_json); + GOTO_IF_ERR(err, earlyexit); - GrpcStatusUtil::Create(status, err); - TRITONSERVER_ErrorDelete(err); - }; - - const std::pair<std::string, std::string>& restricted_kv = - restricted_keys_.Get(RestrictedCategory::SHARED_MEMORY); - new CommonCallData< - ::grpc::ServerAsyncResponseWriter< - inference::CudaSharedMemoryRegisterResponse>, - inference::CudaSharedMemoryRegisterRequest, - inference::CudaSharedMemoryRegisterResponse>( - "CudaSharedMemoryRegister", 0, OnRegisterCudaSharedMemoryRegister, - OnExecuteCudaSharedMemoryRegister, false /* async */, cq_, restricted_kv, - response_delay_); -} + const char* name; + size_t namelen; + err = shm_region_json.MemberAsString("name", &name, &namelen); + GOTO_IF_ERR(err, earlyexit); -void -CommonHandler::RegisterCudaSharedMemoryUnregister() -{ - auto OnRegisterCudaSharedMemoryUnregister = - [this]( - ::grpc::ServerContext* 
ctx, - inference::CudaSharedMemoryUnregisterRequest* request, - ::grpc::ServerAsyncResponseWriter< - inference::CudaSharedMemoryUnregisterResponse>* responder, - void* tag) { - this->service_->RequestCudaSharedMemoryUnregister( - ctx, request, responder, this->cq_, this->cq_, tag); - }; - - auto OnExecuteCudaSharedMemoryUnregister = - [this]( - inference::CudaSharedMemoryUnregisterRequest& request, - inference::CudaSharedMemoryUnregisterResponse* response, - ::grpc::Status* status) { - TRITONSERVER_Error* err = nullptr; - if (request.name().empty()) { - err = shm_manager_->UnregisterAll(TRITONSERVER_MEMORY_GPU); - } else { - err = - shm_manager_->Unregister(request.name(), TRITONSERVER_MEMORY_GPU); - } + uint64_t device_id; + err = shm_region_json.MemberAsUInt("device_id", &device_id); + GOTO_IF_ERR(err, earlyexit); - GrpcStatusUtil::Create(status, err); - TRITONSERVER_ErrorDelete(err); - }; - const std::pair<std::string, std::string>& restricted_kv = - restricted_keys_.Get(RestrictedCategory::SHARED_MEMORY); - - new CommonCallData< - ::grpc::ServerAsyncResponseWriter< - inference::CudaSharedMemoryUnregisterResponse>, - inference::CudaSharedMemoryUnregisterRequest, - inference::CudaSharedMemoryUnregisterResponse>( - "CudaSharedMemoryUnregister", 0, OnRegisterCudaSharedMemoryUnregister, - OnExecuteCudaSharedMemoryUnregister, false /* async */, cq_, - restricted_kv, response_delay_); -} + uint64_t byte_size; + err = shm_region_json.MemberAsUInt("byte_size", &byte_size); + GOTO_IF_ERR(err, earlyexit); -void -CommonHandler::RegisterRepositoryIndex() -{ - auto OnRegisterRepositoryIndex = - [this]( - ::grpc::ServerContext* ctx, - inference::RepositoryIndexRequest* request, - ::grpc::ServerAsyncResponseWriter<inference::RepositoryIndexResponse>* - responder, - void* tag) { - this->service_->RequestRepositoryIndex( - ctx, request, responder, this->cq_, this->cq_, tag); - }; - - auto OnExecuteRepositoryIndex = - [this]( - inference::RepositoryIndexRequest& request, - inference::RepositoryIndexResponse* response, - ::grpc::Status* status) { - TRITONSERVER_Error* err = nullptr; - if (request.repository_name().empty()) { - uint32_t flags = 0; - if (request.ready()) { - flags |= TRITONSERVER_INDEX_FLAG_READY; - } + inference::CudaSharedMemoryStatusResponse::RegionStatus region_status; + region_status.set_name(std::string(name, namelen)); + region_status.set_device_id(device_id); + region_status.set_byte_size(byte_size); - TRITONSERVER_Message* model_index_message = nullptr; - err = TRITONSERVER_ServerModelIndex( - tritonserver_.get(), flags, &model_index_message); - GOTO_IF_ERR(err, earlyexit); + (*response->mutable_regions())[name] = region_status; + } - const char* buffer; - size_t byte_size; - err = TRITONSERVER_MessageSerializeToJson( - model_index_message, &buffer, &byte_size); - GOTO_IF_ERR(err, earlyexit); + earlyexit: + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); + TRITONSERVER_ErrorDelete(err); + reactor->Finish(status); + return reactor; + } - triton::common::TritonJson::Value model_index_json; - err = model_index_json.Parse(buffer, byte_size); - GOTO_IF_ERR(err, earlyexit); + ::grpc::ServerUnaryReactor* SystemSharedMemoryUnregister( + ::grpc::CallbackServerContext* context, + const inference::SystemSharedMemoryUnregisterRequest* request, + inference::SystemSharedMemoryUnregisterResponse* response) override + { + auto* reactor = context->DefaultReactor(); + + // (Optionally) Check client metadata for restricted access. 
+ const std::pair<std::string, std::string>& restricted_kv = + restricted_keys_.Get(RestrictedCategory::SHARED_MEMORY); + if (!restricted_kv.first.empty()) { + const auto& metadata = context->client_metadata(); + auto it = metadata.find(restricted_kv.first); + if (it == metadata.end() || it->second != restricted_kv.second) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + "Missing or mismatched restricted header")); + return reactor; + } + } - err = model_index_json.AssertType( - triton::common::TritonJson::ValueType::ARRAY); - GOTO_IF_ERR(err, earlyexit); + // Core business logic - kept same as original + TRITONSERVER_Error* err = nullptr; + if (request->name().empty()) { + err = shm_manager_->UnregisterAll(TRITONSERVER_MEMORY_CPU); + } else { + err = shm_manager_->Unregister(request->name(), TRITONSERVER_MEMORY_CPU); + } - for (size_t idx = 0; idx < model_index_json.ArraySize(); ++idx) { - triton::common::TritonJson::Value index_json; - err = model_index_json.IndexAsObject(idx, &index_json); - GOTO_IF_ERR(err, earlyexit); + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); + TRITONSERVER_ErrorDelete(err); + reactor->Finish(status); + return reactor; + } - auto model_index = response->add_models(); + // Add here + ::grpc::ServerUnaryReactor* CudaSharedMemoryUnregister( + ::grpc::CallbackServerContext* context, + const inference::CudaSharedMemoryUnregisterRequest* request, + inference::CudaSharedMemoryUnregisterResponse* response) override + { + auto* reactor = context->DefaultReactor(); + + // (Optionally) Check client metadata for restricted access. + const std::pair<std::string, std::string>& restricted_kv = + restricted_keys_.Get(RestrictedCategory::SHARED_MEMORY); + if (!restricted_kv.first.empty()) { + const auto& metadata = context->client_metadata(); + auto it = metadata.find(restricted_kv.first); + if (it == metadata.end() || it->second != restricted_kv.second) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + "Missing or mismatched restricted header")); + return reactor; + } + } - const char* name; - size_t namelen; - err = index_json.MemberAsString("name", &name, &namelen); - GOTO_IF_ERR(err, earlyexit); - model_index->set_name(std::string(name, namelen)); + // Core business logic - kept same as original + TRITONSERVER_Error* err = nullptr; + if (request->name().empty()) { + err = shm_manager_->UnregisterAll(TRITONSERVER_MEMORY_GPU); + } else { + err = shm_manager_->Unregister(request->name(), TRITONSERVER_MEMORY_GPU); + } - if (index_json.Find("version")) { - const char* version; - size_t versionlen; - err = index_json.MemberAsString("version", &version, &versionlen); - GOTO_IF_ERR(err, earlyexit); - model_index->set_version(std::string(version, versionlen)); - } - if (index_json.Find("state")) { - const char* state; - size_t statelen; - err = index_json.MemberAsString("state", &state, &statelen); - GOTO_IF_ERR(err, earlyexit); - model_index->set_state(std::string(state, statelen)); - } - if (index_json.Find("reason")) { - const char* reason; - size_t reasonlen; - err = index_json.MemberAsString("reason", &reason, &reasonlen); - GOTO_IF_ERR(err, earlyexit); - model_index->set_reason(std::string(reason, reasonlen)); - } - } + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); + TRITONSERVER_ErrorDelete(err); + reactor->Finish(status); + return reactor; + } - TRITONSERVER_MessageDelete(model_index_message); - } else { - err = TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_UNSUPPORTED, - "'repository_name' 
specification is not supported"); - } + ::grpc::ServerUnaryReactor* RepositoryIndex( + ::grpc::CallbackServerContext* context, + const inference::RepositoryIndexRequest* request, + inference::RepositoryIndexResponse* response) override + { + auto* reactor = context->DefaultReactor(); + + // (Optionally) Check client metadata for restricted access. + const std::pair<std::string, std::string>& restricted_kv = + restricted_keys_.Get(RestrictedCategory::MODEL_REPOSITORY); + if (!restricted_kv.first.empty()) { + const auto& metadata = context->client_metadata(); + auto it = metadata.find(restricted_kv.first); + if (it == metadata.end() || it->second != restricted_kv.second) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + "Missing or mismatched restricted header")); + return reactor; + } + } - earlyexit: - GrpcStatusUtil::Create(status, err); - TRITONSERVER_ErrorDelete(err); - }; - - const std::pair<std::string, std::string>& restricted_kv = - restricted_keys_.Get(RestrictedCategory::MODEL_REPOSITORY); - new CommonCallData< - ::grpc::ServerAsyncResponseWriter<inference::RepositoryIndexResponse>, - inference::RepositoryIndexRequest, inference::RepositoryIndexResponse>( - "RepositoryIndex", 0, OnRegisterRepositoryIndex, OnExecuteRepositoryIndex, - false /* async */, cq_, restricted_kv, response_delay_); -} + // Core business logic + TRITONSERVER_Error* err = nullptr; + if (request->repository_name().empty()) { + uint32_t flags = 0; + if (request->ready()) { + flags |= TRITONSERVER_INDEX_FLAG_READY; + } -void -CommonHandler::RegisterRepositoryModelLoad() -{ - auto OnRegisterRepositoryModelLoad = - [this]( - ::grpc::ServerContext* ctx, - inference::RepositoryModelLoadRequest* request, - ::grpc::ServerAsyncResponseWriter< - inference::RepositoryModelLoadResponse>* responder, - void* tag) { - this->service_->RequestRepositoryModelLoad( - ctx, request, responder, this->cq_, this->cq_, tag); - }; - - auto OnExecuteRepositoryModelLoad = - [this]( - inference::RepositoryModelLoadRequest& request, - inference::RepositoryModelLoadResponse* response, - ::grpc::Status* status) { - TRITONSERVER_Error* err = nullptr; - if (request.repository_name().empty()) { - std::vector<TRITONSERVER_Parameter*> params; - // WAR for the const-ness check - std::vector<const TRITONSERVER_Parameter*> const_params; - for (const auto& param_proto : request.parameters()) { - if (param_proto.first == "config") { - if (param_proto.second.parameter_choice_case() != - inference::ModelRepositoryParameter::ParameterChoiceCase:: - kStringParam) { - err = TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INVALID_ARG, - (std::string("invalid value type for load parameter '") + - param_proto.first + "', expected string_param.") - .c_str()); - break; - } else { - auto param = TRITONSERVER_ParameterNew( - param_proto.first.c_str(), TRITONSERVER_PARAMETER_STRING, - param_proto.second.string_param().c_str()); - if (param != nullptr) { - params.emplace_back(param); - const_params.emplace_back(param); - } else { - err = TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INTERNAL, - "unexpected error on creating Triton parameter"); + TRITONSERVER_Message* model_index_message = nullptr; + err = TRITONSERVER_ServerModelIndex( + tritonserver_.get(), flags, &model_index_message); + if (err == nullptr) { + const char* buffer; + size_t byte_size; + err = TRITONSERVER_MessageSerializeToJson( + model_index_message, &buffer, &byte_size); + if (err == nullptr) { + triton::common::TritonJson::Value model_index_json; + err = 
model_index_json.Parse(buffer, byte_size); + if (err == nullptr) { + err = model_index_json.AssertType( + triton::common::TritonJson::ValueType::ARRAY); + if (err == nullptr) { + for (size_t idx = 0; idx < model_index_json.ArraySize(); ++idx) { + triton::common::TritonJson::Value index_json; + err = model_index_json.IndexAsObject(idx, &index_json); + if (err != nullptr) { break; } - } - } else if (param_proto.first.rfind("file:", 0) == 0) { - if (param_proto.second.parameter_choice_case() != - inference::ModelRepositoryParameter::ParameterChoiceCase:: - kBytesParam) { - err = TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INVALID_ARG, - (std::string("invalid value type for load parameter '") + - param_proto.first + "', expected bytes_param.") - .c_str()); - break; - } else { - auto param = TRITONSERVER_ParameterBytesNew( - param_proto.first.c_str(), - param_proto.second.bytes_param().data(), - param_proto.second.bytes_param().length()); - if (param != nullptr) { - params.emplace_back(param); - const_params.emplace_back(param); - } else { - err = TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INTERNAL, - "unexpected error on creating Triton parameter"); + + auto model_index = response->add_models(); + + const char* name; + size_t namelen; + err = index_json.MemberAsString("name", &name, &namelen); + if (err != nullptr) { break; } + model_index->set_name(std::string(name, namelen)); + + if (index_json.Find("version")) { + const char* version; + size_t versionlen; + err = index_json.MemberAsString( + "version", &version, &versionlen); + if (err != nullptr) { + break; + } + model_index->set_version(std::string(version, versionlen)); + } + if (index_json.Find("state")) { + const char* state; + size_t statelen; + err = index_json.MemberAsString("state", &state, &statelen); + if (err != nullptr) { + break; + } + model_index->set_state(std::string(state, statelen)); + } + if (index_json.Find("reason")) { + const char* reason; + size_t reasonlen; + err = + index_json.MemberAsString("reason", &reason, &reasonlen); + if (err != nullptr) { + break; + } + model_index->set_reason(std::string(reason, reasonlen)); + } } - } else { - err = TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INVALID_ARG, - (std::string("unrecognized load parameter '") + - param_proto.first + "'.") - .c_str()); - break; } } - if (err == nullptr) { - err = TRITONSERVER_ServerLoadModelWithParameters( - tritonserver_.get(), request.model_name().c_str(), - const_params.data(), const_params.size()); - } - // Assumes no further 'params' access after load API returns - for (auto& param : params) { - TRITONSERVER_ParameterDelete(param); - } - } else { - err = TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_UNSUPPORTED, - "'repository_name' specification is not supported"); } + TRITONSERVER_MessageDelete(model_index_message); + } + } else { + err = TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_UNSUPPORTED, + "'repository_name' specification is not supported"); + } - GrpcStatusUtil::Create(status, err); - TRITONSERVER_ErrorDelete(err); - }; - - const std::pair<std::string, std::string>& restricted_kv = - restricted_keys_.Get(RestrictedCategory::MODEL_REPOSITORY); - new CommonCallData< - ::grpc::ServerAsyncResponseWriter<inference::RepositoryModelLoadResponse>, - inference::RepositoryModelLoadRequest, - inference::RepositoryModelLoadResponse>( - "RepositoryModelLoad", 0, OnRegisterRepositoryModelLoad, - OnExecuteRepositoryModelLoad, true /* async */, cq_, restricted_kv, - response_delay_); -} + ::grpc::Status status; + GrpcStatusUtil::Create(&status, 
err); + TRITONSERVER_ErrorDelete(err); + reactor->Finish(status); + return reactor; + } -void -CommonHandler::RegisterRepositoryModelUnload() -{ - auto OnRegisterRepositoryModelUnload = - [this]( - ::grpc::ServerContext* ctx, - inference::RepositoryModelUnloadRequest* request, - ::grpc::ServerAsyncResponseWriter< - inference::RepositoryModelUnloadResponse>* responder, - void* tag) { - this->service_->RequestRepositoryModelUnload( - ctx, request, responder, this->cq_, this->cq_, tag); - }; - - auto OnExecuteRepositoryModelUnload = - [this]( - inference::RepositoryModelUnloadRequest& request, - inference::RepositoryModelUnloadResponse* response, - ::grpc::Status* status) { - TRITONSERVER_Error* err = nullptr; - if (request.repository_name().empty()) { - // Check if the dependent models should be removed - bool unload_dependents = false; - for (auto param : request.parameters()) { - if (param.first.compare("unload_dependents") == 0) { - const auto& unload_param = param.second; - if (unload_param.parameter_choice_case() != - inference::ModelRepositoryParameter::ParameterChoiceCase:: - kBoolParam) { - err = TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INVALID_ARG, - "invalid value type for 'unload_dependents' parameter, " - "expected " - "bool_param."); - } - unload_dependents = unload_param.bool_param(); + ::grpc::ServerUnaryReactor* RepositoryModelLoad( + ::grpc::CallbackServerContext* context, + const inference::RepositoryModelLoadRequest* request, + inference::RepositoryModelLoadResponse* response) override + { + auto* reactor = context->DefaultReactor(); + + // (Optionally) Check client metadata for restricted access. + const std::pair<std::string, std::string>& restricted_kv = + restricted_keys_.Get(RestrictedCategory::MODEL_REPOSITORY); + if (!restricted_kv.first.empty()) { + const auto& metadata = context->client_metadata(); + auto it = metadata.find(restricted_kv.first); + if (it == metadata.end() || it->second != restricted_kv.second) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + "Missing or mismatched restricted header")); + return reactor; + } + } + + // Core business logic + TRITONSERVER_Error* err = nullptr; + if (request->repository_name().empty()) { + std::vector<TRITONSERVER_Parameter*> params; + // WAR for the const-ness check + std::vector<const TRITONSERVER_Parameter*> const_params; + + for (const auto& param_proto : request->parameters()) { + if (param_proto.first == "config") { + if (param_proto.second.parameter_choice_case() != + inference::ModelRepositoryParameter::ParameterChoiceCase:: + kStringParam) { + err = TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + (std::string("invalid value type for load parameter '") + + param_proto.first + "', expected string_param.") + .c_str()); + break; + } else { + auto param = TRITONSERVER_ParameterNew( + param_proto.first.c_str(), TRITONSERVER_PARAMETER_STRING, + param_proto.second.string_param().c_str()); + if (param != nullptr) { + params.emplace_back(param); + const_params.emplace_back(param); + } else { + err = TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, + "unexpected error on creating Triton parameter"); break; } } - if (err == nullptr) { - if (unload_dependents) { - err = TRITONSERVER_ServerUnloadModelAndDependents( - tritonserver_.get(), request.model_name().c_str()); + } else if (param_proto.first.rfind("file:", 0) == 0) { + if (param_proto.second.parameter_choice_case() != + inference::ModelRepositoryParameter::ParameterChoiceCase:: + kBytesParam) { + err = 
TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + (std::string("invalid value type for load parameter '") + + param_proto.first + "', expected bytes_param.") + .c_str()); + break; + } else { + auto param = TRITONSERVER_ParameterBytesNew( + param_proto.first.c_str(), + param_proto.second.bytes_param().data(), + param_proto.second.bytes_param().length()); + if (param != nullptr) { + params.emplace_back(param); + const_params.emplace_back(param); } else { - err = TRITONSERVER_ServerUnloadModel( - tritonserver_.get(), request.model_name().c_str()); + err = TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, + "unexpected error on creating Triton parameter"); + break; } } } else { err = TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_UNSUPPORTED, - "'repository_name' specification is not supported"); + TRITONSERVER_ERROR_INVALID_ARG, + (std::string("unrecognized load parameter '") + + param_proto.first + "'.") + .c_str()); + break; + } + } + + if (err == nullptr) { + err = TRITONSERVER_ServerLoadModelWithParameters( + tritonserver_.get(), request->model_name().c_str(), + const_params.data(), const_params.size()); + } + + // Assumes no further 'params' access after load API returns + for (auto& param : params) { + TRITONSERVER_ParameterDelete(param); + } + } else { + err = TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_UNSUPPORTED, + "'repository_name' specification is not supported"); + } + + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); + TRITONSERVER_ErrorDelete(err); + reactor->Finish(status); + return reactor; + } + + ::grpc::ServerUnaryReactor* RepositoryModelUnload( + ::grpc::CallbackServerContext* context, + const inference::RepositoryModelUnloadRequest* request, + inference::RepositoryModelUnloadResponse* response) override + { + auto* reactor = context->DefaultReactor(); + + // (Optionally) Check client metadata for restricted access. 
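+    // The same metadata check appears verbatim in every reactor of this
+    // service; a small shared helper would keep each RPC focused on its core
+    // logic. One possible shape (hypothetical helper, sketch only):
+    //
+    //   // Returns true if access is allowed; otherwise finishes the reactor
+    //   // with UNAVAILABLE and returns false.
+    //   bool CheckRestrictedAccess(
+    //       ::grpc::CallbackServerContext* context,
+    //       RestrictedCategory category,
+    //       ::grpc::ServerUnaryReactor* reactor);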
+ const std::pair<std::string, std::string>& restricted_kv = + restricted_keys_.Get(RestrictedCategory::MODEL_REPOSITORY); + if (!restricted_kv.first.empty()) { + const auto& metadata = context->client_metadata(); + auto it = metadata.find(restricted_kv.first); + if (it == metadata.end() || it->second != restricted_kv.second) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + "Missing or mismatched restricted header")); + return reactor; + } + } + + // Core business logic + TRITONSERVER_Error* err = nullptr; + if (request->repository_name().empty()) { + // Check if the dependent models should be removed + bool unload_dependents = false; + for (const auto& param : request->parameters()) { + if (param.first.compare("unload_dependents") == 0) { + const auto& unload_param = param.second; + if (unload_param.parameter_choice_case() != + inference::ModelRepositoryParameter::ParameterChoiceCase:: + kBoolParam) { + err = TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + "invalid value type for 'unload_dependents' parameter, " + "expected bool_param."); + } + unload_dependents = unload_param.bool_param(); + break; + } + } + + if (err == nullptr) { + if (unload_dependents) { + err = TRITONSERVER_ServerUnloadModelAndDependents( + tritonserver_.get(), request->model_name().c_str()); + } else { + err = TRITONSERVER_ServerUnloadModel( + tritonserver_.get(), request->model_name().c_str()); } + } + } else { + err = TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_UNSUPPORTED, + "'repository_name' specification is not supported"); + } + + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); + TRITONSERVER_ErrorDelete(err); + reactor->Finish(status); + return reactor; + } - GrpcStatusUtil::Create(status, err); - TRITONSERVER_ErrorDelete(err); - }; - - const std::pair<std::string, std::string>& restricted_kv = - restricted_keys_.Get(RestrictedCategory::MODEL_REPOSITORY); - new CommonCallData< - ::grpc::ServerAsyncResponseWriter< - inference::RepositoryModelUnloadResponse>, - inference::RepositoryModelUnloadRequest, - inference::RepositoryModelUnloadResponse>( - "RepositoryModelUnload", 0, OnRegisterRepositoryModelUnload, - OnExecuteRepositoryModelUnload, true /* async */, cq_, restricted_kv, - response_delay_); + private: + std::shared_ptr<TRITONSERVER_Server> tritonserver_; + std::shared_ptr<SharedMemoryManager> shm_manager_; + TraceManager* trace_manager_; + RestrictedFeatures restricted_keys_; + ModelInferCallbackHandler model_infer_handler_; +}; + +// +// CommonHandler +// +// A common handler for all non-inference requests. +// +class CommonHandler : public HandlerBase { + public: + CommonHandler( + const std::string& name, + const std::shared_ptr<TRITONSERVER_Server>& tritonserver, + const std::shared_ptr<SharedMemoryManager>& shm_manager, + TraceManager* trace_manager, + inference::GRPCInferenceService::AsyncService* service, + ::grpc::health::v1::Health::AsyncService* health_service, + inference::GRPCInferenceServiceCallback::CallbackService* + non_inference_callback_service, + const RestrictedFeatures& restricted_keys, const uint64_t response_delay); + + // Implement pure virtual functions + void Start() override {} // No-op for callback implementation + void Stop() override {} // No-op for callback implementation + + // Descriptive name of of the handler. 
+ const std::string& Name() const { return name_; } + + void CreateCallbackServices(); + + // Add methods to return the callback services + inference::GRPCInferenceServiceCallback::CallbackService* + GetUnifiedCallbackService() + { + return non_inference_callback_service_; + } + + ::grpc::health::v1::Health::CallbackService* GetHealthCallbackService() + { + return health_callback_service_; + } + + private: + const std::string name_; + std::shared_ptr<TRITONSERVER_Server> tritonserver_; + std::shared_ptr<SharedMemoryManager> shm_manager_; + TraceManager* trace_manager_; + inference::GRPCInferenceService::AsyncService* service_; + ::grpc::health::v1::Health::AsyncService* health_service_; + inference::GRPCInferenceServiceCallback::CallbackService* + non_inference_callback_service_; + ::grpc::health::v1::Health::CallbackService* health_callback_service_; + std::unique_ptr<std::thread> thread_; + RestrictedFeatures restricted_keys_; + const uint64_t response_delay_; +}; + +CommonHandler::CommonHandler( + const std::string& name, + const std::shared_ptr<TRITONSERVER_Server>& tritonserver, + const std::shared_ptr<SharedMemoryManager>& shm_manager, + TraceManager* trace_manager, + inference::GRPCInferenceService::AsyncService* service, + ::grpc::health::v1::Health::AsyncService* health_service, + inference::GRPCInferenceServiceCallback::CallbackService* + non_inference_callback_service, + const RestrictedFeatures& restricted_keys, const uint64_t response_delay) + : name_(name), tritonserver_(tritonserver), shm_manager_(shm_manager), + trace_manager_(trace_manager), service_(service), + health_service_(health_service), + non_inference_callback_service_(non_inference_callback_service), + health_callback_service_(nullptr), restricted_keys_(restricted_keys), + response_delay_(response_delay) +{ + CreateCallbackServices(); +} + +void +CommonHandler::CreateCallbackServices() +{ + // Create the unified callback service for non-inference operations + // Pass all required arguments to the UnifiedCallbackService constructor + non_inference_callback_service_ = new UnifiedCallbackService( + "CommonHandler", tritonserver_, + trace_manager_, // Pass the trace manager from CommonHandler + shm_manager_, + grpc_compression_level::GRPC_COMPRESS_LEVEL_NONE, // Provide a default + // compression level + restricted_keys_, + "" // Provide an empty default for forward_header_pattern + ); + // Create the health callback service + health_callback_service_ = + new HealthCallbackService(tritonserver_, restricted_keys_); } } // namespace @@ -2282,7 +2032,6 @@ Server::Server( builder_.AddListeningPort(server_addr_, credentials, &bound_port_); builder_.SetMaxMessageSize(MAX_GRPC_MESSAGE_SIZE); builder_.RegisterService(&service_); - builder_.RegisterService(&health_service_); builder_.AddChannelArgument( GRPC_ARG_ALLOW_REUSEPORT, options.socket_.reuse_port_); @@ -2368,7 +2117,6 @@ Server::Server( LOG_TABLE_VERBOSE(1, table_printer); } - common_cq_ = builder_.AddCompletionQueue(); model_infer_cq_ = builder_.AddCompletionQueue(); model_stream_infer_cq_ = builder_.AddCompletionQueue(); @@ -2381,8 +2129,15 @@ Server::Server( // A common Handler for other non-inference requests common_handler_.reset(new CommonHandler( "CommonHandler", tritonserver_, shm_manager_, trace_manager_, &service_, - &health_service_, common_cq_.get(), options.restricted_protocols_, - response_delay)); + &health_service_, nullptr /* non_inference_callback_service */, + options.restricted_protocols_, response_delay)); + // Use common_handler_ and register 
services + auto* handler = dynamic_cast<CommonHandler*>(common_handler_.get()); + if (handler != nullptr) { + // Register both the unified service and health service + builder_.RegisterService(handler->GetUnifiedCallbackService()); + builder_.RegisterService(handler->GetHealthCallbackService()); + } // [FIXME] "register" logic is different for infer // Handler for model inference requests. @@ -2548,7 +2303,7 @@ Server::Start() (std::string("Socket '") + server_addr_ + "' already in use ").c_str()); } - common_handler_->Start(); + // Remove this for (auto& model_infer_handler : model_infer_handlers_) { model_infer_handler->Start(); } @@ -2572,13 +2327,13 @@ Server::Stop() // Always shutdown the completion queue after the server. server_->Shutdown(); - common_cq_->Shutdown(); + // common_cq_->Shutdown(); model_infer_cq_->Shutdown(); model_stream_infer_cq_->Shutdown(); // Must stop all handlers explicitly to wait for all the handler // threads to join since they are referencing completion queue, etc. - common_handler_->Stop(); + // common_handler_->Stop(); for (auto& model_infer_handler : model_infer_handlers_) { model_infer_handler->Stop(); } diff --git a/src/grpc/grpc_server.h b/src/grpc/grpc_server.h index a4fc358df0..d4c91265c7 100644 --- a/src/grpc/grpc_server.h +++ b/src/grpc/grpc_server.h @@ -36,6 +36,7 @@ #include "grpc_handler.h" #include "grpc_service.grpc.pb.h" #include "grpc_utils.h" +#include "grpccallback_service.grpc.pb.h" #include "health.grpc.pb.h" #include "infer_handler.h" #include "stream_infer_handler.h" @@ -143,10 +144,13 @@ class Server { inference::GRPCInferenceService::AsyncService service_; ::grpc::health::v1::Health::AsyncService health_service_; + inference::GRPCInferenceServiceCallback::CallbackService + non_inference_callback_service_; + ::grpc::health::v1::Health::CallbackService* health_callback_service_; std::unique_ptr<::grpc::Server> server_; - std::unique_ptr<::grpc::ServerCompletionQueue> common_cq_; + // std::unique_ptr<::grpc::ServerCompletionQueue> common_cq_; std::unique_ptr<::grpc::ServerCompletionQueue> model_infer_cq_; std::unique_ptr<::grpc::ServerCompletionQueue> model_stream_infer_cq_; diff --git a/src/grpc/infer_handler.cc b/src/grpc/infer_handler.cc index 0bc1e9ce0b..e44413fa39 100644 --- a/src/grpc/infer_handler.cc +++ b/src/grpc/infer_handler.cc @@ -113,6 +113,29 @@ InferResponseAlloc( buffer_userp, actual_memory_type, actual_memory_type_id); } +// Make sure to keep InferResponseAllocCallback and OutputBufferQuery logic in +// sync +TRITONSERVER_Error* +InferResponseAllocCallback( + TRITONSERVER_ResponseAllocator* allocator, const char* tensor_name, + size_t byte_size, TRITONSERVER_MemoryType preferred_memory_type, + int64_t preferred_memory_type_id, void* userp, void** buffer, + void** buffer_userp, TRITONSERVER_MemoryType* actual_memory_type, + int64_t* actual_memory_type_id) +{ + AllocPayloadCallback<inference::ModelInferResponse>* payload = + reinterpret_cast<AllocPayloadCallback<inference::ModelInferResponse>*>( + userp); + + // ModelInfer RPC expects exactly one response per request. Hence, + // Get pointer directly from the modified payload instead of the queue. 
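+  // In the callback path the single response object is owned by the gRPC
+  // reactor and is recorded in AllocPayloadCallback::response_ptr_ by
+  // InferAllocatorPayloadCallback(), so the allocator writes into it directly
+  // rather than pulling a response off a response queue as the
+  // completion-queue path does.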
+ inference::ModelInferResponse* response = payload->response_ptr_; + return ResponseAllocatorHelper( + allocator, tensor_name, byte_size, preferred_memory_type, + preferred_memory_type_id, response, payload->shm_map_, buffer, + buffer_userp, actual_memory_type, actual_memory_type_id); +} + // Make sure to keep InferResponseAlloc and OutputBufferQuery logic in sync TRITONSERVER_Error* OutputBufferQuery( @@ -120,8 +143,9 @@ OutputBufferQuery( const char* tensor_name, size_t* byte_size, TRITONSERVER_MemoryType* memory_type, int64_t* memory_type_id) { - AllocPayload<inference::ModelInferResponse>* payload = - reinterpret_cast<AllocPayload<inference::ModelInferResponse>*>(userp); + AllocPayloadCallback<inference::ModelInferResponse>* payload = + reinterpret_cast<AllocPayloadCallback<inference::ModelInferResponse>*>( + userp); return OutputBufferQueryHelper( allocator, tensor_name, byte_size, payload->shm_map_, memory_type, @@ -136,8 +160,9 @@ OutputBufferAttributes( TRITONSERVER_BufferAttributes* buffer_attributes, void* userp, void* buffer_userp) { - AllocPayload<inference::ModelInferResponse>* payload = - reinterpret_cast<AllocPayload<inference::ModelInferResponse>*>(userp); + AllocPayloadCallback<inference::ModelInferResponse>* payload = + reinterpret_cast<AllocPayloadCallback<inference::ModelInferResponse>*>( + userp); return OutputBufferAttributesHelper( allocator, tensor_name, payload->shm_map_, buffer_attributes); @@ -191,12 +216,12 @@ InferGRPCToInputHelper( TRITONSERVER_Error* InferResponseStart(TRITONSERVER_ResponseAllocator* allocator, void* userp) { - AllocPayload<inference::ModelInferResponse>* payload = - reinterpret_cast<AllocPayload<inference::ModelInferResponse>*>(userp); + // AllocPayload<inference::ModelInferResponse>* payload = + // reinterpret_cast<AllocPayload<inference::ModelInferResponse>*>(userp); // ModelInfer RPC expects exactly one response per request. Hence, always call // GetNonDecoupledResponse() to create one response object on response start. 
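+  // With the callback API the response object is supplied up front by the
+  // reactor and recorded in the allocator payload, so there is nothing to
+  // create here; the start callback stays registered with the allocator and
+  // simply reports success.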
- payload->response_queue_->GetNonDecoupledResponse(); + // payload->response_queue_->GetNonDecoupledResponse(); return nullptr; // success } @@ -639,7 +664,7 @@ void InferRequestComplete( TRITONSERVER_InferenceRequest* request, const uint32_t flags, void* userp) { - LOG_VERBOSE(1) << "ModelInferHandler::InferRequestComplete"; + LOG_VERBOSE(1) << "ModelInferHandler::InferRequestComplete!"; RequestReleasePayload* request_release_payload = static_cast<RequestReleasePayload*>(userp); @@ -649,6 +674,406 @@ InferRequestComplete( } } +ModelInferCallbackHandler::ModelInferCallbackHandler( + const std::string& name, + const std::shared_ptr<TRITONSERVER_Server>& tritonserver, + TraceManager* trace_manager, + const std::shared_ptr<SharedMemoryManager>& shm_manager, + grpc_compression_level compression_level, + RestrictedFeatures& restricted_keys, + const std::string& forward_header_pattern) + : name_(name), tritonserver_(tritonserver), trace_manager_(trace_manager), + shm_manager_(shm_manager), compression_level_(compression_level), + restricted_kv_(restricted_keys.Get(RestrictedCategory::INFERENCE)), + header_forward_pattern_(forward_header_pattern), + header_forward_regex_(forward_header_pattern) +{ + FAIL_IF_ERR( + TRITONSERVER_ResponseAllocatorNew( + &allocator_, InferResponseAllocCallback, InferResponseFree, + InferResponseStart), + "creating inference response allocator"); + FAIL_IF_ERR( + TRITONSERVER_ResponseAllocatorSetQueryFunction( + allocator_, OutputBufferQuery), + "setting allocator's query function"); + FAIL_IF_ERR( + TRITONSERVER_ResponseAllocatorSetBufferAttributesFunction( + allocator_, OutputBufferAttributes), + "setting allocator's output buffer attributes function"); +} + +ModelInferCallbackHandler::~ModelInferCallbackHandler() +{ + LOG_TRITONSERVER_ERROR( + TRITONSERVER_ResponseAllocatorDelete(allocator_), + "deleting response allocator"); +} + +/** + * @brief Handles gRPC ModelInfer requests using the callback API pattern + * + * Request flow path: + * 1. Client creates and sends ModelInferRequest via gRPC + * 2. gRPC framework deserializes the protobuf message + * 3. gRPC calls this handler based on service registration + * 4. This function creates a callback state and reactor to manage async + * lifecycle + * 5. 
The Execute method initiates processing with proper ownership transfer + * + * Memory management: + * - CallbackState manages lifecycle of request/response objects + * - Ownership transfers to completion callbacks for async cleanup + * - Response memory allocation handled through allocator_ + * - Shared memory regions tracked and released after completion + * + * @param context The gRPC server context for this request + * @param request The deserialized ModelInferRequest from client + * @param response Output parameter for the ModelInferResponse to client + * @return ::grpc::ServerUnaryReactor* Reactor that signals request completion + */ +::grpc::ServerUnaryReactor* +ModelInferCallbackHandler::HandleModelInfer( + ::grpc::CallbackServerContext* context, + const inference::ModelInferRequest* request, + inference::ModelInferResponse* response) +{ + auto* reactor = context->DefaultReactor(); + + // Check preconditions + if (!ExecutePrecondition(context)) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, "This protocol is restricted")); + return reactor; + } + + // Create callback state + auto callback_state = std::make_unique<CallbackState>( + response, reactor, context, tritonserver_); + + // Execute the request + Execute(context, request, response, reactor, callback_state); + + return reactor; +} + +void +ModelInferCallbackHandler::InferResponseComplete( + TRITONSERVER_InferenceResponse* response, const uint32_t flags, void* userp) +{ + LOG_VERBOSE(1) << "[InferResponseComplete START] Received userp " + "(CallbackState*) address: " + << userp; + std::unique_ptr<CallbackState> callback_state( + static_cast<CallbackState*>(userp)); + LOG_VERBOSE(1) << "[InferResponseComplete] CallbackState unique_ptr now owns " + "state at address: " + << callback_state.get(); + if (response != nullptr) { + // Use the pre-allocated response directly from the callback state + ::grpc::Status status = ::grpc::Status::OK; + + // Get the response from the payload's response queue as a fallback + LOG_VERBOSE(1) + << "[InferResponseComplete] Attempting to retrieve response pointer " + "directly from callback_state->response_ which points to: " + << callback_state->response_; + inference::ModelInferResponse* grpc_response = callback_state->response_; + + // If not available in callback state, try to get from response queue + if (grpc_response == nullptr) { + LOG_VERBOSE(1) + << "[InferResponseComplete] >>> Fallback Triggered! 
grpc_response " + "from state was NULL, attempting fallback from queue."; + grpc_response = callback_state->alloc_payload_.response_ptr_; + } + + if (grpc_response != nullptr) { + // Process the response + LOG_VERBOSE(1) + << "InferResponseComplete: Checking response object at address: " + << grpc_response; + TRITONSERVER_Error* err = InferResponseCompleteCommonCallback( + callback_state->tritonserver_.get(), response, *grpc_response, + callback_state->alloc_payload_); + + if (err != nullptr) { + GrpcStatusUtil::Create(&status, err); + TRITONSERVER_ErrorDelete(err); + } + } else { + status = ::grpc::Status( + ::grpc::StatusCode::INTERNAL, + "response object not found in callback"); + } + + // For callback API, we complete the RPC by finishing the reactor + // Only finish the reactor when we get the final response or on error + if ((flags & TRITONSERVER_RESPONSE_COMPLETE_FINAL) || !status.ok()) { + callback_state->reactor_->Finish(status); + } + } else { + // Handle null response case + callback_state->reactor_->Finish( + ::grpc::Status(::grpc::StatusCode::INTERNAL, "null response")); + } + +#ifdef TRITON_ENABLE_TRACING + if (callback_state->trace_ != nullptr) { + callback_state->trace_timestamps_.emplace_back(std::make_pair( + "INFER_RESPONSE_COMPLETE", TraceManager::CaptureTimestamp())); + } +#endif // TRITON_ENABLE_TRACING + + // Always delete the TRITONSERVER_InferenceResponse + if (response != nullptr) { + LOG_TRITONSERVER_ERROR( + TRITONSERVER_InferenceResponseDelete(response), + "deleting inference response"); + } +} + +bool +ModelInferCallbackHandler::ExecutePrecondition( + ::grpc::CallbackServerContext* context) +{ + if (!restricted_kv_.first.empty()) { + const auto& metadata = context->client_metadata(); + const auto it = metadata.find(restricted_kv_.first); + return (it != metadata.end()) && (it->second == restricted_kv_.second); + } + return true; +} + +// Implement the new private helper function +TRITONSERVER_Error* +ModelInferCallbackHandler::ForwardHeadersAsParametersCallback( + TRITONSERVER_InferenceRequest* irequest, + const ::grpc::CallbackServerContext* context) +{ + TRITONSERVER_Error* err = nullptr; + // Use the members stored in *this* specific handler instance + if (!header_forward_pattern_.empty()) { + const auto& metadata = + context->client_metadata(); // Use the passed context + for (const auto& pair : metadata) { + // Need to convert grpc::string_ref to std::string for RE2/Triton API + std::string key_str(pair.first.data(), pair.first.length()); + std::string value_str(pair.second.data(), pair.second.length()); + + // Use the regex member stored in *this* handler instance + if (RE2::PartialMatch(key_str, header_forward_regex_)) { + err = TRITONSERVER_InferenceRequestSetStringParameter( + irequest, key_str.c_str(), value_str.c_str()); + if (err != nullptr) { + break; // Exit loop on error + } + } + } + } + return err; +} + +void +ModelInferCallbackHandler::Execute( + ::grpc::CallbackServerContext* context, + const inference::ModelInferRequest* request, + inference::ModelInferResponse* response, + ::grpc::ServerUnaryReactor* reactor, + std::unique_ptr<CallbackState>& callback_state) +{ + TRITONSERVER_Error* err = nullptr; + TRITONSERVER_InferenceRequest* irequest = nullptr; + LOG_VERBOSE(1) << "[Execute START] Incoming response object address: " + << response; + // --- Step 1: Receive & Validate --- + int64_t requested_model_version; + err = GetModelVersionFromString( + request->model_version(), &requested_model_version); + + // Check if model has decoupled 
transaction policy (not supported by this RPC) + if (err == nullptr) { + uint32_t txn_flags; + // Query model properties + err = TRITONSERVER_ServerModelTransactionProperties( + tritonserver_.get(), request->model_name().c_str(), + requested_model_version, &txn_flags, nullptr /* voidp */); + if ((err == nullptr) && (txn_flags & TRITONSERVER_TXN_DECOUPLED) != 0) { + // Set error if decoupled + err = TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_UNSUPPORTED, + "ModelInfer RPC doesn't support models with decoupled " + "transaction policy"); + } + } + + // --- Step 2: Prepare Triton Request Object --- + if (err == nullptr) { + // Create the core Triton request object + err = TRITONSERVER_InferenceRequestNew( + &irequest, tritonserver_.get(), request->model_name().c_str(), + requested_model_version); + } + + // Populate request metadata (ID, sequence flags, priority, params, etc.) + if (err == nullptr) { + StateParameters state_params; // Temporary params for this call scope + err = SetInferenceRequestMetadata(irequest, *request, state_params); + } + + // Forward relevant gRPC headers as Triton parameters + if (err == nullptr) { + err = ForwardHeadersAsParametersCallback(irequest, context); + } + + // --- Step 3: Process Input Tensors --- + if (err == nullptr) { + // Parse inputs from request, handle shared memory (if any), + // serialize string data, and add data pointers/attributes to irequest. + // Serialized data stored in callback_state->serialized_data_ + // SHM info stored in callback_state->shm_regions_info_ + err = InferGRPCToInput( + tritonserver_, shm_manager_, *request, + &callback_state->serialized_data_, irequest, + &callback_state->shm_regions_info_); + } + + if (err == nullptr) { + // Use the externally provided response object directly. + // Store the external response pointer in the state for later access. + callback_state->response_ = response; + LOG_VERBOSE(1) << "[Execute] Stored response object address in " + "callback_state->response_: " + << callback_state->response_; + // Clear the externally provided response object directly. + response->Clear(); // Ensure it's empty before Triton writes to it + } + + // Prepare the allocator payload: info needed by allocation callback later. 
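+  // The payload carries everything the allocation callbacks will need:
+  // the reactor-owned response pointer, shared-memory mappings for any
+  // outputs that requested them, classification parameters, and the
+  // serialized input data whose lifetime must span the whole request.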
+ if (err == nullptr) { + err = InferAllocatorPayloadCallback<inference::ModelInferResponse>( + tritonserver_, shm_manager_, *request, + std::move(callback_state->serialized_data_), callback_state->response_, + &callback_state->alloc_payload_, &callback_state->shm_regions_info_); + } + + // --- Step 5: Setup Automatic Cleanup Payloads & Register Callbacks --- + // Create payload for request release callback (manages irequest lifetime) + auto request_release_payload = std::make_unique<RequestReleasePayload>( + std::shared_ptr<TRITONSERVER_InferenceRequest>( + irequest, [](TRITONSERVER_InferenceRequest* r) { + // Custom deleter: Ensures delete is called via shared_ptr lifecycle + if (r != nullptr) { + LOG_TRITONSERVER_ERROR( + TRITONSERVER_InferenceRequestDelete(r), + "deleting inference request via shared_ptr custom deleter"); + } + })); + + // Register the release callback (cleans up request_release_payload & + // irequest) + if (err == nullptr) { + err = TRITONSERVER_InferenceRequestSetReleaseCallback( + irequest, InferRequestComplete, request_release_payload.get()); + } + + // Register the response callback (processes result, finishes RPC, cleans up + // callback_state) + if (err == nullptr) { + // Note: Passing callback_state.get() transfers potential ownership to the + // callback mechanism upon success (see step 7). + err = TRITONSERVER_InferenceRequestSetResponseCallback( + irequest, allocator_, &callback_state->alloc_payload_, + InferResponseComplete, callback_state.get()); + } + + // --- Optional: Setup Tracing --- + TRITONSERVER_InferenceTrace* triton_trace = nullptr; +#ifdef TRITON_ENABLE_TRACING + if (err == nullptr && trace_manager_ != nullptr) { + // Setup and start tracing if configured + GrpcServerCarrier carrier(context); + auto start_options = + trace_manager_->GetTraceStartOptions(carrier, request->model_name()); + callback_state->trace_ = + std::move(trace_manager_->SampleTrace(start_options)); + if (callback_state->trace_ != nullptr) { + triton_trace = callback_state->trace_->trace_; + } + } +#endif // TRITON_ENABLE_TRACING + + // Get request ID for logging, handle potential null irequest if error + // occurred early + const char* request_id_cstr = ""; + std::string request_id = "<unknown>"; + if (irequest != nullptr) { + auto id_err = TRITONSERVER_InferenceRequestId(irequest, &request_id_cstr); + if (id_err == nullptr && request_id_cstr != nullptr && + strlen(request_id_cstr) > 0) { + request_id = request_id_cstr; + } + TRITONSERVER_ErrorDelete(id_err); // Delete error from ID retrieval if any + } + + + // --- Step 6: Start Asynchronous Inference --- + if (err == nullptr) { + err = TRITONSERVER_ServerInferAsync( + tritonserver_.get(), irequest, triton_trace); + } + + // --- Step 7/8: Handle Outcome (Success or Error) --- + if (err == nullptr) { + // --- Success Path --- + // Inference successfully submitted to Triton core. + // Release ownership of payloads to the callback mechanism. + // Callbacks (InferResponseComplete, InferRequestComplete) are now + // responsible for cleanup. + LOG_VERBOSE(1) << "[Execute SUCCESS] Releasing ownership of callback_state " + "at address: " + << callback_state.get(); + callback_state.release(); + request_release_payload.release(); + // Execute function finishes here; gRPC call waits for reactor->Finish() in + // callback. + LOG_VERBOSE(1) << "[request id: " << request_id << "] " + << "Async inference submitted successfully."; + + } else { + // --- Error Path --- + // An error occurred during setup before submitting to Triton. 
+ LOG_VERBOSE(1) << "[request id: " << request_id << "] " + << "Setup failed before submitting inference: " + << TRITONSERVER_ErrorMessage(err); + + // Create gRPC status from Triton error + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); + + // Perform explicit cleanup as callbacks won't run + TRITONSERVER_ErrorDelete(err); // Delete the primary Triton error + if (irequest != nullptr) { + // Explicitly delete the request object as the release callback won't run + // Note: The shared_ptr in request_release_payload will handle this + // gracefully + // when the unique_ptr goes out of scope below, due to the custom + // deleter. However, explicit deletion here is safe and clear. + LOG_TRITONSERVER_ERROR( + TRITONSERVER_InferenceRequestDelete(irequest), + "explicitly deleting inference request due to setup error"); + irequest = + nullptr; // Avoid potential double delete if shared_ptr logic changes + } + // Note: callback_state and request_release_payload unique_ptrs will + // automatically clean up their managed objects when they go out of + // scope now, as .release() was not called. + + // Immediately finish the gRPC call with the error status + reactor->Finish(status); + // Execute function finishes here. + } +} //=========================================================================== // The following section contains the handling mechanism for ModelInfer RPC. // This implementation is tuned towards performance and reducing latency. @@ -819,13 +1244,16 @@ ResponseAllocatorHelper( *actual_memory_type = preferred_memory_type; *actual_memory_type_id = preferred_memory_type_id; + LOG_VERBOSE(1) << "AllocatorHelper: Modifying response object at address: " + << response; // We add an output contents even if the 'byte_size' == 0 because we // expect to have a contents for every output. inference::ModelInferResponse::InferOutputTensor* output_tensor = response->add_outputs(); output_tensor->set_name(tensor_name); std::string* raw_output = response->add_raw_output_contents(); - + LOG_VERBOSE(1) << "AllocatorHelper: After add_outputs for " << tensor_name + << ", response->outputs_size() = " << response->outputs_size(); if (byte_size > 0) { const auto& pr = shm_map.find(tensor_name); if (pr != shm_map.end()) { diff --git a/src/grpc/infer_handler.h b/src/grpc/infer_handler.h index 40d5ce4806..dbcfb7dc53 100644 --- a/src/grpc/infer_handler.h +++ b/src/grpc/infer_handler.h @@ -34,13 +34,13 @@ #include <regex> #include <thread> +#include "../restricted_features.h" #include "../tracer.h" #include "grpc_handler.h" #include "grpc_service.grpc.pb.h" #include "grpc_utils.h" #include "triton/common/logging.h" #include "triton/core/tritonserver.h" - // Unique IDs are only needed when debugging. They only appear in // verbose logging. #ifndef NDEBUG @@ -353,6 +353,46 @@ struct AllocPayload { std::list<std::string> serialized_data_; }; +// +// AllocPayloadCallback +// +// Simple structure that carries the userp payload needed for +// allocation specifically for the Callback API, holding a direct +// pointer to the gRPC response object. +// +template <typename ResponseType> +struct AllocPayloadCallback { + using ClassificationMap = std::unordered_map<std::string, uint32_t>; + + // Constructor initializes the response pointer to null + explicit AllocPayloadCallback() + : response_ptr_(nullptr), response_alloc_count_(0) + { + } + + // Destructor - does nothing with response_ptr_ as ownership + // lies with the gRPC framework or CallbackState unique_ptr initially. 
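+  // Filled in by InferAllocatorPayloadCallback() and later recovered from the
+  // allocator's 'userp' argument in InferResponseAllocCallback(),
+  // OutputBufferQuery() and OutputBufferAttributes(); the reinterpret_casts in
+  // those functions must stay in sync with this type.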
+ ~AllocPayloadCallback() = default; // Default destructor is sufficient + + // Direct pointer to the gRPC response object managed externally + // (by gRPC reactor or CallbackState). + ResponseType* response_ptr_; + + // Counter for allocations related to this payload. + uint32_t response_alloc_count_; + + // Map for shared memory information for output tensors. + TensorShmMap shm_map_; + + // Map for classification parameters for output tensors. + ClassificationMap classification_map_; + + // Used to extend the lifetime of serialized input data (e.g., for BYTES + // tensors) needed during the allocation phase (though data originates from + // the request). + std::list<std::string> serialized_data_; +}; + template <typename ResponseType> TRITONSERVER_Error* InferAllocatorPayload( @@ -430,6 +470,83 @@ InferAllocatorPayload( return nullptr; // Success } +template <typename ResponseType> +TRITONSERVER_Error* +InferAllocatorPayloadCallback( + const std::shared_ptr<TRITONSERVER_Server>& tritonserver, + const std::shared_ptr<SharedMemoryManager>& shm_manager, + const inference::ModelInferRequest& request, + std::list<std::string>&& serialized_data, + inference::ModelInferResponse* response_ptr, + AllocPayloadCallback<ResponseType>* alloc_payload, + std::vector<std::shared_ptr<const SharedMemoryManager::SharedMemoryInfo>>* + shm_regions_info) +{ + alloc_payload->response_ptr_ = response_ptr; + alloc_payload->shm_map_.clear(); + alloc_payload->classification_map_.clear(); + alloc_payload->serialized_data_ = std::move(serialized_data); + + // If any of the outputs use shared memory, then we must calculate + // the memory address for that output and store it in the allocator + // payload so that it is available when the allocation callback is + // invoked. + for (const auto& io : request.outputs()) { + std::string region_name; + int64_t offset; + size_t byte_size; + bool has_shared_memory; + RETURN_IF_ERR(ParseSharedMemoryParams< + inference::ModelInferRequest::InferRequestedOutputTensor>( + io, &has_shared_memory, ®ion_name, &offset, &byte_size)); + + bool has_classification; + uint32_t classification_count; + RETURN_IF_ERR(ParseClassificationParams( + io, &has_classification, &classification_count)); + + if (has_shared_memory && has_classification) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + "output can't set both 'shared_memory_region' and " + "'classification'"); + } + + if (has_shared_memory) { + void* base; + TRITONSERVER_MemoryType memory_type; + int64_t memory_type_id; + std::shared_ptr<const SharedMemoryManager::SharedMemoryInfo> shm_info = + nullptr; + RETURN_IF_ERR(shm_manager->GetMemoryInfo( + region_name, offset, byte_size, &base, &memory_type, &memory_type_id, + &shm_info)); + shm_regions_info->emplace_back(shm_info); + + if (memory_type == TRITONSERVER_MEMORY_GPU) { +#ifdef TRITON_ENABLE_GPU + char* cuda_handle; + RETURN_IF_ERR(shm_manager->GetCUDAHandle( + region_name, reinterpret_cast<cudaIpcMemHandle_t**>(&cuda_handle))); + alloc_payload->shm_map_.emplace( + io.name(), + ShmInfo(base, byte_size, memory_type, memory_type_id, cuda_handle)); +#endif + } else { + alloc_payload->shm_map_.emplace( + io.name(), ShmInfo( + base, byte_size, memory_type, memory_type_id, + nullptr /* cuda_ipc_handle */)); + } + } else if (has_classification) { + alloc_payload->classification_map_.emplace( + io.name(), classification_count); + } + } + + return nullptr; // Success +} + TRITONSERVER_Error* InferGRPCToInputHelper( const std::string& input_name, const std::string& 
model_name, const TRITONSERVER_DataType tensor_dt, const TRITONSERVER_DataType input_dt, @@ -694,6 +811,199 @@ InferResponseCompleteCommon( return nullptr; // success } +// Common function to populate the gRPC ModelInferResponse protobuf from the +// TRITONSERVER_InferenceResponse C structure. Handles metadata, parameters, +// output tensor data transfer, and classification formatting. Used by the +// callback API path. +template <typename ResponseType> +TRITONSERVER_Error* +InferResponseCompleteCommonCallback( + TRITONSERVER_Server* server, TRITONSERVER_InferenceResponse* iresponse, + inference::ModelInferResponse& response, + const AllocPayloadCallback<ResponseType>& alloc_payload) +{ + RETURN_IF_ERR(TRITONSERVER_InferenceResponseError(iresponse)); + + const char *model_name, *id; + int64_t model_version; + RETURN_IF_ERR(TRITONSERVER_InferenceResponseModel( + iresponse, &model_name, &model_version)); + RETURN_IF_ERR(TRITONSERVER_InferenceResponseId(iresponse, &id)); + + response.set_id(id); + response.set_model_name(model_name); + response.set_model_version(std::to_string(model_version)); + + // Propagate response parameters. + uint32_t parameter_count; + RETURN_IF_ERR(TRITONSERVER_InferenceResponseParameterCount( + iresponse, ¶meter_count)); + for (uint32_t pidx = 0; pidx < parameter_count; ++pidx) { + const char* name; + TRITONSERVER_ParameterType type; + const void* vvalue; + RETURN_IF_ERR(TRITONSERVER_InferenceResponseParameter( + iresponse, pidx, &name, &type, &vvalue)); + inference::InferParameter& param = (*response.mutable_parameters())[name]; + switch (type) { + case TRITONSERVER_PARAMETER_BOOL: + param.set_bool_param(*(reinterpret_cast<const bool*>(vvalue))); + break; + case TRITONSERVER_PARAMETER_INT: + param.set_int64_param(*(reinterpret_cast<const int64_t*>(vvalue))); + break; + case TRITONSERVER_PARAMETER_STRING: + param.set_string_param(reinterpret_cast<const char*>(vvalue)); + break; + case TRITONSERVER_PARAMETER_DOUBLE: + param.set_double_param(*(reinterpret_cast<const double*>(vvalue))); + break; + case TRITONSERVER_PARAMETER_BYTES: + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_UNSUPPORTED, + "Response parameter of type 'TRITONSERVER_PARAMETER_BYTES' is not " + "currently supported"); + break; + } + } + + // Go through each response output and transfer information to the + // corresponding GRPC response output. + uint32_t output_count; + RETURN_IF_ERR( + TRITONSERVER_InferenceResponseOutputCount(iresponse, &output_count)); + if (output_count != (uint32_t)response.outputs_size()) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, "response output count mismatch"); + } + + for (uint32_t output_idx = 0; output_idx < output_count; ++output_idx) { + const char* cname; + TRITONSERVER_DataType datatype; + const int64_t* shape; + uint64_t dim_count; + const void* base; + size_t byte_size; + TRITONSERVER_MemoryType memory_type; + int64_t memory_type_id; + void* userp; + + RETURN_IF_ERR(TRITONSERVER_InferenceResponseOutput( + iresponse, output_idx, &cname, &datatype, &shape, &dim_count, &base, + &byte_size, &memory_type, &memory_type_id, &userp)); + + const std::string name(cname); + + // There are usually very few outputs so fastest just to look for + // the one we want... could create a map for cases where there are + // a large number of outputs. Or rely on order to be same... 
+ inference::ModelInferResponse::InferOutputTensor* output = nullptr; + for (auto& io : *(response.mutable_outputs())) { + if (io.name() == name) { + output = &io; + break; + } + } + + if (output == nullptr) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, + "unable to find expected response output"); + } + + // If this output was requested as classification then remove the + // raw output from the response and instead return classification + // results as a string tensor + const auto itr = alloc_payload.classification_map_.find(name); + if (itr == alloc_payload.classification_map_.end()) { + // Not classification... + output->set_datatype(TRITONSERVER_DataTypeString(datatype)); + for (size_t idx = 0; idx < dim_count; idx++) { + output->add_shape(shape[idx]); + } + } else { + // Classification + const uint32_t classification_count = itr->second; + + // For classification need to determine the batch size, if any, + // because need to use that to break up the response for each + // batch entry. + uint32_t batch_size = 0; + + uint32_t batch_flags; + RETURN_IF_ERR(TRITONSERVER_ServerModelBatchProperties( + server, model_name, model_version, &batch_flags, + nullptr /* voidp */)); + if ((dim_count > 0) && + ((batch_flags & TRITONSERVER_BATCH_FIRST_DIM) != 0)) { + batch_size = shape[0]; + } + + // Determine the batch1 byte size of the tensor... needed when + // the response tensor batch-size > 1 so that we know how to + // stride though the tensor data. + size_t batch1_element_count = 1; + for (size_t idx = ((batch_size == 0) ? 0 : 1); idx < dim_count; idx++) { + batch1_element_count *= shape[idx]; + } + + const size_t batch1_byte_size = + batch1_element_count * TRITONSERVER_DataTypeByteSize(datatype); + + // Create the classification contents + std::string serialized; + + size_t class_offset = 0; + for (uint32_t bs = 0; bs < std::max((uint32_t)1, batch_size); ++bs) { + std::vector<std::string> class_strs; + RETURN_IF_ERR(TopkClassifications( + iresponse, output_idx, + reinterpret_cast<const char*>(base) + class_offset, + ((class_offset + batch1_byte_size) > byte_size) ? 0 + : batch1_byte_size, + datatype, classification_count, &class_strs)); + + // Serialize for binary representation... + for (const auto& str : class_strs) { + uint32_t len = str.size(); + serialized.append(reinterpret_cast<const char*>(&len), sizeof(len)); + if (len > 0) { + serialized.append(str); + } + } + + class_offset += batch1_byte_size; + } + + // Update the output with new datatype, shape and contents. + output->set_datatype( + TRITONSERVER_DataTypeString(TRITONSERVER_TYPE_BYTES)); + + if (batch_size > 0) { + output->add_shape(batch_size); + } + output->add_shape( + std::min(classification_count, (uint32_t)batch1_element_count)); + + (*response.mutable_raw_output_contents())[output_idx] = + std::move(serialized); + } + } + + // Make sure response doesn't exceed GRPC limits. + if (response.ByteSizeLong() > MAX_GRPC_MESSAGE_SIZE) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + std::string( + "Response has byte size " + + std::to_string(response.ByteSizeLong()) + + " which exceeds gRPC's byte size limit " + std::to_string(INT_MAX) + + ".") + .c_str()); + } + + return nullptr; // success +} // // InferHandlerState // @@ -982,7 +1292,7 @@ class InferHandlerState { // FIXME: Is there a better way to put task on the // completion queue rather than using alarm object? // The alarm object will add a new task to the back of the - // completion queue when it expires or when it’s cancelled. 
+ // completion queue when it expires or when it's cancelled. state->alarm_.Set( cq_, gpr_now(gpr_clock_type::GPR_CLOCK_REALTIME), state); } @@ -1596,6 +1906,81 @@ InferHandler<ServiceType, ServerResponderType, RequestType, ResponseType>:: return err; } +class ModelInferCallbackHandler { + public: + ModelInferCallbackHandler( + const std::string& name, + const std::shared_ptr<TRITONSERVER_Server>& tritonserver, + TraceManager* trace_manager, + const std::shared_ptr<SharedMemoryManager>& shm_manager, + grpc_compression_level compression_level, + RestrictedFeatures& restricted_keys, + const std::string& forward_header_pattern); + + ~ModelInferCallbackHandler(); + + ::grpc::ServerUnaryReactor* HandleModelInfer( + ::grpc::CallbackServerContext* context, + const inference::ModelInferRequest* request, + inference::ModelInferResponse* response); + + private: + // Define CallbackState first, before any methods that use it + struct CallbackState { + CallbackState( + inference::ModelInferResponse* response, + ::grpc::ServerUnaryReactor* reactor, + ::grpc::CallbackServerContext* context, + const std::shared_ptr<TRITONSERVER_Server>& tritonserver) + : response_(response), reactor_(reactor), context_(context), + tritonserver_(tritonserver) + { + } + + inference::ModelInferResponse* response_; + ::grpc::ServerUnaryReactor* reactor_; + ::grpc::CallbackServerContext* context_; + std::shared_ptr<TRITONSERVER_Server> tritonserver_; + + // Request resources + AllocPayloadCallback<inference::ModelInferResponse> alloc_payload_; + std::list<std::string> serialized_data_; + std::vector<std::shared_ptr<const SharedMemoryManager::SharedMemoryInfo>> + shm_regions_info_; + +#ifdef TRITON_ENABLE_TRACING + std::shared_ptr<TraceManager::Trace> trace_; +#endif // TRITON_ENABLE_TRACING + }; + + TRITONSERVER_Error* ForwardHeadersAsParametersCallback( + TRITONSERVER_InferenceRequest* irequest, + const ::grpc::CallbackServerContext* context); + // Now Execute can use CallbackState + void Execute( + ::grpc::CallbackServerContext* context, + const inference::ModelInferRequest* request, + inference::ModelInferResponse* response, + ::grpc::ServerUnaryReactor* reactor, + std::unique_ptr<CallbackState>& callback_state); + + static void InferResponseComplete( + TRITONSERVER_InferenceResponse* response, const uint32_t flags, + void* userp); + + bool ExecutePrecondition(::grpc::CallbackServerContext* context); + + const std::string name_; + std::shared_ptr<TRITONSERVER_Server> tritonserver_; + TraceManager* trace_manager_; + std::shared_ptr<SharedMemoryManager> shm_manager_; + TRITONSERVER_ResponseAllocator* allocator_; + + grpc_compression_level compression_level_; + const std::pair<std::string, std::string> restricted_kv_; + const std::string header_forward_pattern_; + re2::RE2 header_forward_regex_; +}; // // ModelInferHandler // diff --git a/test_grpc_callbacks.sh b/test_grpc_callbacks.sh new file mode 100755 index 0000000000..8058af5cdd --- /dev/null +++ b/test_grpc_callbacks.sh @@ -0,0 +1,148 @@ +#!/bin/bash +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. 
+# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# Note: Before running this script, start Triton server in explicit model control mode: +# tritonserver --model-repository=/path/to/model/repository --model-control-mode=explicit + +# Default server URL +SERVER_URL=${1:-"localhost:8001"} +PROTO_PATH="/mnt/builddir/triton-server/_deps/repo-common-src/protobuf" +PROTO_FILE="${PROTO_PATH}/grpccallback_service.proto" +HEALTH_PROTO="${PROTO_PATH}/health.proto" + +# Colors for output +GREEN='\033[0;32m' +RED='\033[0;31m' +NC='\033[0m' # No Color +BOLD='\033[1m' + +# Function to print test results +print_result() { + local test_name=$1 + local result=$2 + if [ $result -eq 0 ]; then + echo -e "${test_name}: ${GREEN}PASSED${NC}" + else + echo -e "${test_name}: ${RED}FAILED${NC}" + fi +} + +echo -e "\n${BOLD}Testing gRPC Callback RPCs against ${SERVER_URL}${NC}\n" + +# Test Health Check +echo -e "\n${BOLD}Testing Health Check:${NC}" +grpcurl -proto ${HEALTH_PROTO} \ + --import-path ${PROTO_PATH} \ + -plaintext ${SERVER_URL} \ + grpc.health.v1.Health/Check +print_result "Health Check" $? + +# Test Repository Index +echo -e "\n${BOLD}Testing Repository Index:${NC}" +grpcurl -proto ${PROTO_FILE} \ + --import-path ${PROTO_PATH} \ + -plaintext ${SERVER_URL} \ + inference.GRPCInferenceServiceCallback/RepositoryIndex +print_result "Repository Index" $? + +# Test Model Load +echo -e "\n${BOLD}Testing Model Load:${NC}" +grpcurl -proto ${PROTO_FILE} \ + --import-path ${PROTO_PATH} \ + -plaintext -d '{"model_name": "simple"}' \ + ${SERVER_URL} \ + inference.GRPCInferenceServiceCallback/RepositoryModelLoad +print_result "Model Load" $? + +# Wait for model to load +sleep 2 + +# Test Model Unload +echo -e "\n${BOLD}Testing Model Unload:${NC}" +grpcurl -proto ${PROTO_FILE} \ + --import-path ${PROTO_PATH} \ + -plaintext -d '{"model_name": "simple"}' \ + ${SERVER_URL} \ + inference.GRPCInferenceServiceCallback/RepositoryModelUnload +print_result "Model Unload" $? + +# Test Server Live +echo -e "\n${BOLD}Testing Server Live:${NC}" +grpcurl -proto ${PROTO_FILE} \ + --import-path ${PROTO_PATH} \ + -plaintext ${SERVER_URL} \ + inference.GRPCInferenceServiceCallback/ServerLive +print_result "Server Live" $? 
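+
+# For a live and ready server, the liveness/readiness RPCs above and below are
+# expected to report '{"live": true}' and '{"ready": true}' respectively
+# (assuming the callback service mirrors the standard KServe
+# ServerLive/ServerReady response messages).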
+ +# Test Server Ready +echo -e "\n${BOLD}Testing Server Ready:${NC}" +grpcurl -proto ${PROTO_FILE} \ + --import-path ${PROTO_PATH} \ + -plaintext ${SERVER_URL} \ + inference.GRPCInferenceServiceCallback/ServerReady +print_result "Server Ready" $? + +# Load model again before testing Model Ready +echo -e "\n${BOLD}Loading model for Model Ready test:${NC}" +grpcurl -proto ${PROTO_FILE} \ + --import-path ${PROTO_PATH} \ + -plaintext -d '{"model_name": "simple"}' \ + ${SERVER_URL} \ + inference.GRPCInferenceServiceCallback/RepositoryModelLoad +print_result "Model Load" $? + +# Wait for model to load +sleep 2 + +# Test Model Ready +echo -e "\n${BOLD}Testing Model Ready:${NC}" +grpcurl -proto ${PROTO_FILE} \ + --import-path ${PROTO_PATH} \ + -plaintext -d '{"name": "simple"}' \ + ${SERVER_URL} \ + inference.GRPCInferenceServiceCallback/ModelReady +print_result "Model Ready" $? + +# Test Server Metadata +echo -e "\n${BOLD}Testing Server Metadata:${NC}" +grpcurl -proto ${PROTO_FILE} \ + --import-path ${PROTO_PATH} \ + -plaintext ${SERVER_URL} \ + inference.GRPCInferenceServiceCallback/ServerMetadata +print_result "Server Metadata" $? + +# Test Model Metadata +echo -e "\n${BOLD}Testing Model Metadata:${NC}" +grpcurl -proto ${PROTO_FILE} \ + --import-path ${PROTO_PATH} \ + -plaintext -d '{"name": "simple"}' \ + ${SERVER_URL} \ + inference.GRPCInferenceServiceCallback/ModelMetadata +print_result "Model Metadata" $? + +echo -e "\n${BOLD}Test Summary:${NC}" +echo "----------------------------------------" \ No newline at end of file
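
The script above exercises only the metadata, repository, and health RPCs. A minimal smoke test for the unary inference path is sketched below. It is illustrative only: it assumes the callback proto also defines a ModelInfer RPC (as the ModelInferCallbackHandler in this change suggests), that the standard Triton example model "simple" (two INT32 inputs of shape [1,16]) is loaded, and it reuses the script's ${PROTO_FILE}, ${PROTO_PATH}, and ${SERVER_URL} variables:

    grpcurl -proto ${PROTO_FILE} --import-path ${PROTO_PATH} -plaintext \
      -d '{
            "model_name": "simple",
            "inputs": [
              {"name": "INPUT0", "datatype": "INT32", "shape": [1, 16],
               "contents": {"int_contents": [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15]}},
              {"name": "INPUT1", "datatype": "INT32", "shape": [1, 16],
               "contents": {"int_contents": [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15]}}
            ]
          }' \
      ${SERVER_URL} \
      inference.GRPCInferenceServiceCallback/ModelInfer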