diff --git a/src/libtorch.cc b/src/libtorch.cc
index 99732cf..6ac3536 100644
--- a/src/libtorch.cc
+++ b/src/libtorch.cc
@@ -105,7 +105,10 @@ class ModelState : public BackendModel {
 
   bool EnabledCacheCleaning() { return enable_cache_cleaning_; }
   bool EnabledWeightSharing() { return enable_weight_sharing_; }
-  const std::vector<std::string>& ModelOutputs() { return output_names_; }
+  const std::map<std::string, std::pair<int64_t, int64_t>>& ModelOutputs()
+  {
+    return model_outputs_;
+  }
 
  private:
   ModelState(TRITONBACKEND_Model* triton_model);
@@ -145,9 +148,14 @@ class ModelState : public BackendModel {
       std::pair<bool, int64_t>, std::shared_ptr<torch::jit::script::Module>>
       torch_models_;
 
-  // List of all the outputs specified in the output section of model
-  // configuration.
-  std::vector<std::string> output_names_;
+  // model_outputs is a map that contains unique outputs that the model must
+  // provide. The first element of the pair is the index of the output in the
+  // model configuration's output section and the second is its index in the
+  // state section; -1 is used if the output is not listed in that section.
+  // The outputs listed in the state section of the model configuration can
+  // intersect with the outputs section. If an output is specified in both
+  // sections, the backend must also return the output state to the client.
+  std::map<std::string, std::pair<int64_t, int64_t>> model_outputs_;
 };
 
 TRITONSERVER_Error*
@@ -172,6 +180,49 @@ ModelState::Create(TRITONBACKEND_Model* triton_model, ModelState** state)
     RETURN_IF_ERROR((*state)->SetModelConfig());
   }
 
+  auto& model_outputs = (*state)->model_outputs_;
+  // Parse the output states in the model configuration
+  triton::common::TritonJson::Value sequence_batching;
+  if ((*state)->ModelConfig().Find("sequence_batching", &sequence_batching)) {
+    triton::common::TritonJson::Value states;
+    if (sequence_batching.Find("state", &states)) {
+      for (size_t i = 0; i < states.ArraySize(); i++) {
+        triton::common::TritonJson::Value state;
+        RETURN_IF_ERROR(states.IndexAsObject(i, &state));
+        std::string output_state_name;
+        RETURN_IF_ERROR(
+            state.MemberAsString("output_name", &output_state_name));
+        auto it = model_outputs.find(output_state_name);
+        if (it == model_outputs.end()) {
+          model_outputs.insert({output_state_name, std::make_pair(-1, i)});
+        } else {
+          it->second.second = i;
+        }
+      }
+    }
+  }
+
+  // Parse the output names in the model configuration
+  triton::common::TritonJson::Value outputs;
+  RETURN_IF_ERROR((*state)->ModelConfig().MemberAsArray("output", &outputs));
+  for (size_t i = 0; i < outputs.ArraySize(); i++) {
+    triton::common::TritonJson::Value output;
+    THROW_IF_BACKEND_INSTANCE_ERROR(outputs.IndexAsObject(i, &output));
+
+    // Use names from ModelConfig by reference since the model
+    // config will persist longer than this inference execution.
+    std::string output_name;
+    THROW_IF_BACKEND_INSTANCE_ERROR(
+        output.MemberAsString("name", &output_name));
+
+    auto it = model_outputs.find(output_name);
+    if (it == model_outputs.end()) {
+      model_outputs.insert({output_name, std::make_pair(i, -1)});
+    } else {
+      it->second.first = i;
+    }
+  }
+
   RETURN_IF_ERROR((*state)->ParseParameters());
 
   return nullptr;  // success
@@ -185,22 +236,6 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model)
       enable_jit_executor_pair_({false, true}),
       enable_nvfuser_pair_({false, false})
 {
-  output_names_.clear();
-
-  triton::common::TritonJson::Value ios;
-  THROW_IF_BACKEND_INSTANCE_ERROR(ModelConfig().MemberAsArray("output", &ios));
-  for (size_t i = 0; i < ios.ArraySize(); i++) {
-    triton::common::TritonJson::Value io;
-    THROW_IF_BACKEND_INSTANCE_ERROR(ios.IndexAsObject(i, &io));
-
-    // Use names from ModelConfig by reference since the model
-    // config will persist longer than this inference execution.
-    const char* io_name;
-    size_t io_name_len;
-    THROW_IF_BACKEND_INSTANCE_ERROR(
-        io.MemberAsString("name", &io_name, &io_name_len));
-    output_names_.emplace_back(io_name);
-  }
 }
 
 TRITONSERVER_Error*
@@ -698,6 +733,11 @@ ModelInstanceState::ModelInstanceState(
     if (have_corrid) {
       expected_input_cnt += 1;
     }
+    // Add the state inputs to the expected count
+    triton::common::TritonJson::Value states;
+    if (sequence_batching.Find("state", &states)) {
+      expected_input_cnt += states.ArraySize();
+    }
   }
 
   supports_batching_ = model_state_->MaxBatchSize() > 0;
@@ -991,6 +1031,48 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
       }
     }
   }
+  triton::common::TritonJson::Value sequence_batching;
+  if (model_state_->ModelConfig().Find(
+          "sequence_batching", &sequence_batching)) {
+    triton::common::TritonJson::Value states;
+    if (sequence_batching.Find("state", &states)) {
+      for (size_t i = 0; i < states.ArraySize(); i++) {
+        triton::common::TritonJson::Value state;
+        RETURN_IF_ERROR(states.IndexAsObject(i, &state));
+        std::string state_name;
+        RETURN_IF_ERROR(state.MemberAsString("input_name", &state_name));
+        AddInputToMap(naming_convention, allowed_inputs, state_name, i);
+
+        // Validate data type
+        std::string state_dtype;
+        RETURN_IF_ERROR(state.MemberAsString("data_type", &state_dtype));
+        const auto pr = ModelConfigDataTypeToTorchType(state_dtype);
+        if (!pr.first && (state_dtype != "TYPE_STRING")) {
+          return TRITONSERVER_ErrorNew(
+              TRITONSERVER_ERROR_INTERNAL,
+              ("unsupported datatype " + state_dtype + " for input state '" +
+               state_name + "' for model '" + model_state_->Name() + "'")
+                  .c_str());
+        }
+
+        // Validate shape for String inputs. Only allow 1 dimension.
+        if (state_dtype == "TYPE_STRING") {
+          std::vector<int64_t> dims;
+          RETURN_IF_ERROR(ParseShape(state, "dims", &dims));
+          if ((dims.size() + (supports_batching_ ? 1 : 0)) > 1) {
+            return TRITONSERVER_ErrorNew(
+                TRITONSERVER_ERROR_INTERNAL,
+                ("Triton only supports 1 dimensional List of String as input "
+                 "for "
+                 "'" +
+                 std::string(state_name) + "' for model '" +
+                 model_state_->Name() + "'")
+                    .c_str());
+          }
+        }
+      }
+    }
+  }
 
   triton::common::TritonJson::Value batch_inputs;
   RETURN_IF_ERROR(
@@ -1085,6 +1166,54 @@ ModelInstanceState::ValidateOutputs()
     output_dtype_map_[io_name] = ConvertTorchTypeToDataType(pr.second);
   }
 
+  triton::common::TritonJson::Value sequence_batching;
+  if (model_state_->ModelConfig().Find(
+          "sequence_batching", &sequence_batching)) {
+    triton::common::TritonJson::Value states;
+    if (sequence_batching.Find("state", &states)) {
+      for (size_t i = 0; i < states.ArraySize(); i++) {
+        triton::common::TritonJson::Value state;
+        RETURN_IF_ERROR(states.IndexAsObject(i, &state));
+        std::string state_name;
+        RETURN_IF_ERROR(state.MemberAsString("output_name", &state_name));
+        std::string state_dtype;
+        RETURN_IF_ERROR(state.MemberAsString("data_type", &state_dtype));
+        std::vector<int64_t> dims;
+        RETURN_IF_ERROR(ParseShape(state, "dims", &dims));
+
+        // For state, the naming convention is enforced to be NAMED_INDEX
+        int start_pos = state_name.find(deliminator);
+        op_index = std::atoi(state_name.substr(start_pos + 2).c_str());
+
+        const auto pr = ModelConfigDataTypeToTorchType(state_dtype);
+        if (!pr.first && (state_dtype != "TYPE_STRING")) {
+          return TRITONSERVER_ErrorNew(
+              TRITONSERVER_ERROR_INTERNAL,
+              ("unsupported datatype " + state_dtype + " for state '" +
+               state_name + "' for model '" + model_state_->Name() + "'")
+                  .c_str());
+        }
+
+        // Validate shape for String outputs. Only allow 1 dimension.
+        if (state_dtype == "TYPE_STRING") {
+          if ((dims.size() + (supports_batching_ ? 1 : 0)) > 1) {
+            return TRITONSERVER_ErrorNew(
+                TRITONSERVER_ERROR_INTERNAL,
+                ("Triton only supports 1 dimensional List of String as output "
+                 "for "
+                 "'" +
+                 std::string(state_name) + "' for model '" +
+                 model_state_->Name() + "'")
+                    .c_str());
+          }
+        }
+
+        output_index_map_[state_name] = op_index;
+        output_dtype_map_[state_name] = ConvertTorchTypeToDataType(pr.second);
+      }
+    }
+  }
+
   return nullptr;  // success
 }
 
@@ -1274,14 +1403,14 @@ ModelInstanceState::ProcessRequests(
   if (!all_response_failed) {
     for (const auto& name : model_state_->ModelOutputs()) {
-      int op_index = output_index_map_[name];
+      int op_index = output_index_map_[name.first];
       if ((op_index < 0) || (op_index > max_index)) {
         RESPOND_ALL_AND_SET_TRUE_IF_ERROR(
             responses, request_count, all_response_failed,
             TRITONSERVER_ErrorNew(
                 TRITONSERVER_ERROR_INVALID_ARG,
                 std::string(
-                    "The output " + std::string(name) +
+                    "The output " + std::string(name.first) +
                     " in the model configuration refers to an output index "
                     "which doesn't exist. This model has " +
                     std::to_string(max_index + 1) + " outputs")
                     .c_str()));
@@ -1608,6 +1737,61 @@ ModelInstanceState::GetNamingConvention(
     }
   }
 
+  triton::common::TritonJson::Value sequence_batching;
+  if (model_state_->ModelConfig().Find(
+          "sequence_batching", &sequence_batching)) {
+    // If we need to manage state for the model, then we need to check that
+    // the naming of the state adheres to both the input and output conventions
+    triton::common::TritonJson::Value states;
+    if (sequence_batching.Find("state", &states)) {
+      if (*naming_convention != NamingConvention::NAMED_INDEX) {
+        return TRITONSERVER_ErrorNew(
+            TRITONSERVER_ERROR_INVALID_ARG,
+            ("PyTorch model '" + model_state_->Name() +
+             "' is using sequence batching with state but not all inputs and "
+             "outputs follow the <name>__<index> naming convention. ")
+                .c_str());
+      }
+    }
+
+    for (size_t i = 0; i < states.ArraySize(); i++) {
+      triton::common::TritonJson::Value state;
+      RETURN_IF_ERROR(states.IndexAsObject(i, &state));
+      std::string name_entry =
+          io_kind == "input" ? "input_name" : "output_name";
+      std::string state_name;
+      RETURN_IF_ERROR(state.MemberAsString(name_entry.c_str(), &state_name));
+      int start_pos = state_name.find(deliminator);
+      if (start_pos == -1) {
+        return TRITONSERVER_ErrorNew(
+            TRITONSERVER_ERROR_INVALID_ARG,
+            ("PyTorch model '" + model_state_->Name() +
+             "' is using sequence batching with state but state '" +
+             state_name +
+             "' does not follow the <name>__<index> naming convention. ")
+                .c_str());
+      } else {
+        // check if the index part of the name is not an integer
+        std::string index_str = state_name.substr(start_pos + 2);
+        bool is_int = true;
+        for (auto itr = index_str.begin(); itr != index_str.end(); itr++) {
+          if (std::isdigit(*itr) == 0) {
+            is_int = false;
+          }
+        }
+        if (!is_int) {
+          return TRITONSERVER_ErrorNew(
+              TRITONSERVER_ERROR_INVALID_ARG,
+              ("PyTorch model '" + model_state_->Name() +
+               "' is using sequence batching with state but state '" +
+               state_name +
+               "' does not follow the <name>__<index> naming convention. ")
+                  .c_str());
+        }
+      }
+    }
+  }
+
   return nullptr;  // success
 }
 
@@ -1789,10 +1973,11 @@ SetStringInputTensor(
 }
 
 bool
-SetStringOutputBuffer(
+SetStringBuffer(
     torch::List<std::string>* tensor, TRITONBACKEND_Response** response,
-    TRITONBACKEND_Output* response_output, const size_t tensor_element_count,
-    cudaStream_t stream, std::string* serialized)
+    TRITONBACKEND_Output* response_output, TRITONBACKEND_State* response_state,
+    const size_t tensor_element_count, cudaStream_t stream,
+    std::string* serialized, bool state)
 {
   bool cuda_copy = false;
 
@@ -1814,15 +1999,26 @@
 
   TRITONSERVER_MemoryType actual_memory_type = TRITONSERVER_MEMORY_CPU;
   int64_t actual_memory_type_id = 0;
+  TRITONSERVER_Error* err;
   void* buffer;
-  auto err = TRITONBACKEND_OutputBuffer(
-      response_output, &buffer, serialized->size(), &actual_memory_type,
-      &actual_memory_type_id);
-  if (err != nullptr) {
-    RESPOND_AND_SET_NULL_IF_ERROR(response, err);
-    return cuda_copy;
-  }
+  if (!state) {
+    auto err = TRITONBACKEND_OutputBuffer(
+        response_output, &buffer, serialized->size(), &actual_memory_type,
+        &actual_memory_type_id);
+    if (err != nullptr) {
+      RESPOND_AND_SET_NULL_IF_ERROR(response, err);
+      return cuda_copy;
+    }
+  } else {
+    auto err = TRITONBACKEND_StateBuffer(
+        response_state, &buffer, serialized->size(), &actual_memory_type,
+        &actual_memory_type_id);
+    if (err != nullptr) {
+      RESPOND_AND_SET_NULL_IF_ERROR(response, err);
+      return cuda_copy;
+    }
+  }
 
   // Copy the serialized tensor into the allocated buffer.
   bool cuda_used = false;
   err = CopyBuffer(
@@ -1837,9 +2033,38 @@
     return cuda_copy;
   }
 
+  if (state) {
+    RESPOND_AND_SET_NULL_IF_ERROR(
+        response, TRITONBACKEND_StateUpdate(response_state));
+  }
+
   return cuda_copy;
 }
 
+bool
+SetStringOutputBuffer(
+    torch::List<std::string>* tensor, TRITONBACKEND_Response** response,
+    TRITONBACKEND_Output* response_output, const size_t tensor_element_count,
+    cudaStream_t stream, std::string* serialized)
+{
+  return SetStringBuffer(
+      tensor, response, response_output, nullptr /* response_state */,
+      tensor_element_count, stream, serialized, false /* state */);
+}
+
+bool
+SetStringStateBuffer(
+    torch::List<std::string>* tensor, TRITONBACKEND_Response** response,
+    TRITONBACKEND_State* response_state, const size_t tensor_element_count,
+    cudaStream_t stream, std::string* serialized)
+{
+  return SetStringBuffer(
+      tensor, response, nullptr /* response_output */, response_state,
+      tensor_element_count, stream, serialized, true /* state */);
+}
+
 TRITONSERVER_Error*
 ModelInstanceState::SetInputTensors(
     size_t total_batch_size, TRITONBACKEND_Request** requests,
@@ -2026,9 +2251,10 @@ ModelInstanceState::ReadOutputTensors(
   bool cuda_copy = false;
   // The serialized string buffer must be valid until output copies are done
   std::vector<std::unique_ptr<std::string>> string_buffer;
-  for (size_t idx = 0; idx < model_state_->ModelOutputs().size(); idx++) {
-    std::string name = model_state_->ModelOutputs()[idx];
-    int op_index = output_index_map_[name];
+  for (auto& output : model_state_->ModelOutputs()) {
+    int op_index = output_index_map_[output.first];
+    auto name = output.first;
+    auto output_tensor_pair = output.second;
 
     if (output_tensors[op_index].isTensor()) {
       torch::Tensor output_flat;
@@ -2086,10 +2312,22 @@
                  "' is a scalar which is not supported.")
                  .c_str());
       }
+      if (output_tensor_pair.first != -1) {
+        responder.ProcessTensor(
+            name, output_dtype, batchn_shape, output_buffer, memory_type,
+            memory_id);
+      }
+      if (output_tensor_pair.second != -1) {
+        std::vector<TRITONBACKEND_State*> states;
+        states = responder.ProcessStateTensor(
+            name, output_dtype, batchn_shape, output_buffer, memory_type,
+            memory_id);
+        // Update the states
+        for (auto& state : states) {
+          RETURN_IF_ERROR(TRITONBACKEND_StateUpdate(state));
+        }
+      }
 
-      responder.ProcessTensor(
-          name, output_dtype, batchn_shape, output_buffer, memory_type,
-          memory_id);
     } else {
       responder.ProcessBatchOutput(
           name, *batch_output, output_buffer, memory_type, memory_id);
@@ -2119,15 +2357,30 @@
 
       // Only need a response tensor for requested outputs.
       if (response != nullptr) {
-        TRITONBACKEND_Output* response_output;
+        if (output_tensor_pair.first != -1) {
+          TRITONBACKEND_Output* response_output;
+          RESPOND_AND_SET_NULL_IF_ERROR(
+              &response, TRITONBACKEND_ResponseOutput(
+                             response, &response_output, name.c_str(),
+                             TRITONSERVER_TYPE_BYTES, batchn_shape.data(),
+                             batchn_shape.size()));
+          string_buffer.emplace_back(new std::string());
+          cuda_copy |= SetStringOutputBuffer(
+              &output_list, &response, response_output, tensor_element_cnt,
+              GetCudaStreamByInstanceKind(), string_buffer.back().get());
+        }
+      }
+      if (output_tensor_pair.second != -1) {
+        TRITONBACKEND_State* response_state;
         RESPOND_AND_SET_NULL_IF_ERROR(
-            &response, TRITONBACKEND_ResponseOutput(
-                           response, &response_output, name.c_str(),
+            &response, TRITONBACKEND_StateNew(
+                           &response_state, request, name.c_str(),
                            TRITONSERVER_TYPE_BYTES, batchn_shape.data(),
                            batchn_shape.size()));
+
         string_buffer.emplace_back(new std::string());
-        cuda_copy |= SetStringOutputBuffer(
-            &output_list, &response, response_output, tensor_element_cnt,
+        cuda_copy |= SetStringStateBuffer(
+            &output_list, &response, response_state, tensor_element_cnt,
             GetCudaStreamByInstanceKind(), string_buffer.back().get());
       }
     }
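
For reviewers, the sketch below (not part of the patch; the ResolveOutputs helper and the OUTPUT__* names are illustrative only) shows the bookkeeping that model_outputs_ introduces: each output name maps to a pair of indices, the first into the output section of the model configuration and the second into the sequence_batching state section, with -1 marking a section the name does not appear in. An entry with both indices set means the tensor is both returned in the response and written back through TRITONBACKEND_StateUpdate.

// Standalone sketch of the model_outputs_ bookkeeping (illustrative only).
#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <utility>
#include <vector>

// name -> (index in the "output" section, index in the "state" section).
// -1 means the name does not appear in that section.
using ModelOutputMap = std::map<std::string, std::pair<int64_t, int64_t>>;

ModelOutputMap
ResolveOutputs(
    const std::vector<std::string>& output_names,
    const std::vector<std::string>& state_output_names)
{
  ModelOutputMap model_outputs;
  // States first: an output that only feeds the next sequence step keeps an
  // output index of -1.
  for (size_t i = 0; i < state_output_names.size(); i++) {
    model_outputs[state_output_names[i]] = {-1, static_cast<int64_t>(i)};
  }
  // Outputs second: a name already present as a state keeps its state index
  // and also records its output index, meaning the backend must return it to
  // the client in addition to updating the state.
  for (size_t i = 0; i < output_names.size(); i++) {
    auto it = model_outputs.find(output_names[i]);
    if (it == model_outputs.end()) {
      model_outputs[output_names[i]] = {static_cast<int64_t>(i), -1};
    } else {
      it->second.first = static_cast<int64_t>(i);
    }
  }
  return model_outputs;
}

int
main()
{
  // OUTPUT__1 appears in both sections, so it would be sent in the response
  // and written back through TRITONBACKEND_StateUpdate.
  auto outputs = ResolveOutputs({"OUTPUT__0", "OUTPUT__1"}, {"OUTPUT__1"});
  for (const auto& entry : outputs) {
    std::cout << entry.first << ": output index " << entry.second.first
              << ", state index " << entry.second.second << "\n";
  }
  return 0;
}

The two passes mirror the order used in ModelState::Create above: state entries are inserted first, then the loop over the output section fills in or overrides the output index.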