diff --git a/src/libtorch.cc b/src/libtorch.cc
index 7e0d288..16ff2c1 100644
--- a/src/libtorch.cc
+++ b/src/libtorch.cc
@@ -25,7 +25,10 @@
 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 #include <stdint.h>
+
+#include <cstdint>
 #include <exception>
+
 #include "libtorch_utils.h"
 #include "triton/backend/backend_common.h"
 #include "triton/backend/backend_input_collector.h"
@@ -502,6 +505,10 @@ class ModelInstanceState : public BackendModelInstance {
       triton::common::TritonJson::Value& sequence_batching,
       const std::string& control_kind, bool required, bool* have_control);
   TRITONSERVER_Error* ValidateInputs(const size_t expected_input_cnt);
+  void AddInputToMap(
+      NamingConvention naming_convention,
+      const std::vector<std::string> allowed_inputs, const std::string& io_name,
+      const uint32_t index);
   TRITONSERVER_Error* ValidateOutputs();
   void Execute(
       std::vector<TRITONBACKEND_Response*>* responses,
@@ -538,6 +545,7 @@ class ModelInstanceState : public BackendModelInstance {
   // Map from configuration name for an input to the index of
   // that input in the model.
   std::unordered_map<std::string, int> input_index_map_;
+  uint32_t batch_input_count_ = 0;
 
   // Map from configuration name for an output to the index of
   // that output in the model.
@@ -607,6 +615,12 @@ ModelInstanceState::ModelInstanceState(
     if (model_state->ModelConfig().Find("input", &inputs)) {
       expected_input_cnt = inputs.ArraySize();
     }
+
+    triton::common::TritonJson::Value config_batch_inputs;
+    if (model_state->ModelConfig().Find("batch_input", &config_batch_inputs)) {
+      batch_input_count_ = config_batch_inputs.ArraySize();
+      expected_input_cnt += batch_input_count_;
+    }
   }
 
   // If this is a sequence model then make sure that the required
@@ -757,6 +771,43 @@ ModelInstanceState::ValidateTypedSequenceControl(
   return nullptr;  // success
 }
 
+void
+ModelInstanceState::AddInputToMap(
+    NamingConvention naming_convention,
+    const std::vector<std::string> allowed_inputs, const std::string& io_name,
+    const uint32_t index)
+{
+  std::string deliminator = "__";
+
+  if (is_dict_input_) {
+    // If dictionary, index is irrelevant but we use the map to store the
+    // input names since they are the keys for the dictionary
+    input_index_map_[io_name] = index;
+  } else {
+    switch (naming_convention) {
+      case NamingConvention::FORWARD_ARGUMENT: {
+        auto itr =
+            std::find(allowed_inputs.begin(), allowed_inputs.end(), io_name);
+        if (itr != allowed_inputs.end()) {
+          input_index_map_[io_name] =
+              std::distance(allowed_inputs.begin(), itr);
+        }
+        return;
+      }
+      case NamingConvention::NAMED_INDEX: {
+        int start_pos = io_name.find(deliminator);
+        int ip_index = std::atoi(io_name.substr(start_pos + 2).c_str());
+        input_index_map_[io_name] = ip_index;
+        return;
+      }
+      case NamingConvention::STRICT_CONFIG_ORDERING: {
+        input_index_map_[io_name] = index;
+        return;
+      }
+    }
+  }
+}
+
 TRITONSERVER_Error*
 ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
 {
@@ -822,8 +873,6 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
 
   triton::common::TritonJson::Value ios;
   RETURN_IF_ERROR(model_state_->ModelConfig().MemberAsArray("input", &ios));
-  std::string deliminator = "__";
-  int ip_index = 0;
 
   if (ios.ArraySize() == 0) {
     return TRITONSERVER_ErrorNew(
@@ -842,34 +891,7 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
     // Validate name
     std::string io_name;
     RETURN_IF_ERROR(io.MemberAsString("name", &io_name));
-    if (is_dict_input_) {
-      // If dictionary, index is irrelevant but we use the map to store the
-      // input names since they are the keys for the dictionary
-      input_index_map_[io_name] = i;
-    } else {
-      switch (naming_convention) {
-        case NamingConvention::FORWARD_ARGUMENT: {
-          auto itr =
-              std::find(allowed_inputs.begin(), allowed_inputs.end(), io_name);
-          if (itr != allowed_inputs.end()) {
-            input_index_map_[io_name] =
-                std::distance(allowed_inputs.begin(), itr);
-          }
-          break;
-        }
-        case NamingConvention::NAMED_INDEX: {
-          int start_pos = io_name.find(deliminator);
-          ip_index = std::atoi(io_name.substr(start_pos + 2).c_str());
-          input_index_map_[io_name] = ip_index;
-          break;
-        }
-        case NamingConvention::STRICT_CONFIG_ORDERING: {
-          input_index_map_[io_name] = i;
-          break;
-        }
-      }
-    }
-
+    AddInputToMap(naming_convention, allowed_inputs, io_name, i);
     // Validate data type
     std::string io_dtype;
     RETURN_IF_ERROR(io.MemberAsString("data_type", &io_dtype));
@@ -906,6 +928,18 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
     }
   }
 
+  triton::common::TritonJson::Value batch_inputs;
+  RETURN_IF_ERROR(
+      model_state_->ModelConfig().MemberAsArray("batch_input", &batch_inputs));
+  size_t i = 0;
+  for (const auto& batch_input : StateForModel()->BatchInputs()) {
+    for (const auto& input_name : batch_input.TargetNames()) {
+      AddInputToMap(
+          naming_convention, allowed_inputs, input_name, i + ios.ArraySize());
+      i++;
+    }
+  }
+
   return nullptr;  // success
 }
 
@@ -1312,12 +1346,12 @@ ModelInstanceState::Execute(
       torch::jit::overrideCanFuseOnCPU(false);
       torch::jit::overrideCanFuseOnGPU(false);
       torch::jit::setTensorExprFuserEnabled(false);
-      torch::jit::fuser::cuda::setEnabled(true);
+      torch::jit::fuser::cuda::setEnabled(true);
     } else {
      torch::jit::overrideCanFuseOnCPU(true);
       torch::jit::overrideCanFuseOnGPU(true);
       torch::jit::setTensorExprFuserEnabled(true);
-      torch::jit::fuser::cuda::setEnabled(false);
+      torch::jit::fuser::cuda::setEnabled(false);
     }
   }
 
@@ -1725,7 +1759,8 @@ ModelInstanceState::SetInputTensors(
   // request as the representative for the input tensors.
   uint32_t input_count;
   RETURN_IF_ERROR(TRITONBACKEND_RequestInputCount(requests[0], &input_count));
-  input_tensors->resize(input_count);
+
+  input_tensors->resize(input_count + batch_input_count_);
   for (uint32_t input_idx = 0; input_idx < input_count; input_idx++) {
     TRITONBACKEND_Input* input;
     RETURN_IF_ERROR(
@@ -1761,9 +1796,9 @@
         batchn_shape[0] += GetElementCount(input_shape, input_dims_count);
       }
-    }
-    else {
-      batchn_shape = std::vector<int64_t>(input_shape, input_shape + input_dims_count);
+    } else {
+      batchn_shape =
+          std::vector<int64_t>(input_shape, input_shape + input_dims_count);
       if (supports_batching_) {
         batchn_shape[0] = total_batch_size;
       }
     }
@@ -1828,6 +1863,36 @@
     }
   }
 
+  for (const auto& batch_input : StateForModel()->BatchInputs()) {
+    std::vector<int64_t> shape;
+    collector->BatchInputShape(batch_input, &shape);
+
+    for (const auto& input_name : batch_input.TargetNames()) {
+      input_names->emplace_back(input_name.c_str());
+
+      const char* dst_buffer;
+      size_t dst_buffer_byte_size;
+      TRITONSERVER_MemoryType dst_memory_type;
+      int64_t dst_memory_type_id;
+
+      // Batch inputs are always created on CPU
+      RESPOND_ALL_AND_SET_NULL_IF_ERROR(
+          (*responses), responses->size(),
+          collector->ProcessBatchInput(
+              batch_input, nullptr, 0, {{TRITONSERVER_MEMORY_CPU, 0}},
+              &dst_buffer, &dst_buffer_byte_size, &dst_memory_type,
+              &dst_memory_type_id));
+
+      const auto torch_dtype =
+          ConvertDataTypeToTorchType(batch_input.DataType());
+
+      torch::Tensor input_tensor = torch::from_blob(
+          const_cast<char*>(dst_buffer), shape,
+          updated_options.dtype(torch_dtype.second));
+      (*input_tensors)[input_index_map_[input_name]] = input_tensor;
+    }
+  }
+
   // Finalize...
   *cuda_copy |= collector->Finalize();
 
@@ -1887,9 +1952,11 @@
 
       // Output tensors may not reside on the same device as model
       torch::Device tensor_device = output_flat.device();
-      const auto memory_type = (tensor_device.type() == torch::kCPU) ? TRITONSERVER_MEMORY_CPU
-                                                                     : TRITONSERVER_MEMORY_GPU;
-      const auto memory_id = (tensor_device.type() == torch::kCPU) ? 0 : tensor_device.index();
+      const auto memory_type = (tensor_device.type() == torch::kCPU)
+                                   ? TRITONSERVER_MEMORY_CPU
+                                   : TRITONSERVER_MEMORY_GPU;
+      const auto memory_id =
+          (tensor_device.type() == torch::kCPU) ? 0 : tensor_device.index();
 
       // Batch output doesn't support string data type yet, as it is not trivial
       // to parse string output
@@ -1906,16 +1973,16 @@
           return TRITONSERVER_ErrorNew(
               TRITONSERVER_ERROR_INVALID_ARG,
               (std::string("output '") + name +
-               "' is a scalar which is not supported.")
+               "' is a scalar which is not supported.")
                   .c_str());
         }
         responder.ProcessTensor(
-            name, output_dtype, batchn_shape, output_buffer,
-            memory_type, memory_id);
+            name, output_dtype, batchn_shape, output_buffer, memory_type,
+            memory_id);
       } else {
         responder.ProcessBatchOutput(
-            name, *batch_output, output_buffer, memory_type, memory_id);
+            name, *batch_output, output_buffer, memory_type, memory_id);
       }
     } else if (output_tensors[op_index].isList()) {
       // Custom handling for string/bytes tensor...
diff --git a/src/libtorch_utils.cc b/src/libtorch_utils.cc
index 699c742..49c13aa 100644
--- a/src/libtorch_utils.cc
+++ b/src/libtorch_utils.cc
@@ -152,7 +152,7 @@ ParseParameter(
 #ifdef TRITON_ENABLE_GPU
 TRITONSERVER_Error*
 ConvertCUDAStatusToTritonError(
-    cudaError_t cuda_error,TRITONSERVER_Error_Code code, const char* msg)
+    cudaError_t cuda_error, TRITONSERVER_Error_Code code, const char* msg)
 {
   if (cuda_error != cudaSuccess) {
     return TRITONSERVER_ErrorNew(
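
For reference, below is a minimal standalone sketch (not part of the patch) of the index-mapping behavior the patch centralizes in AddInputToMap: under the NAMED_INDEX naming convention an input name resolves to the index encoded after the "__" delimiter, while batch inputs are appended after the inputs declared in the model configuration (the i + ios.ArraySize() term in ValidateInputs). The input names "INPUT__0", "INPUT__1" and "BATCH_ELEMENT_COUNT" are hypothetical examples, not identifiers taken from the patch.

#include <cstdlib>
#include <iostream>
#include <string>
#include <unordered_map>

// Standalone sketch of how configuration input names could map to tensor
// indices after this change; the names used here are illustrative only.
int
main()
{
  std::unordered_map<std::string, int> input_index_map;
  const std::string deliminator = "__";

  // NAMED_INDEX convention: the index is parsed from the suffix after "__",
  // mirroring io_name.substr(start_pos + 2) in AddInputToMap.
  for (const std::string io_name : {"INPUT__0", "INPUT__1"}) {
    const size_t start_pos = io_name.find(deliminator);
    const int ip_index = std::atoi(io_name.substr(start_pos + 2).c_str());
    input_index_map[io_name] = ip_index;
  }

  // Batch inputs are placed after the regular inputs from the configuration,
  // mirroring the `i + ios.ArraySize()` computation in ValidateInputs.
  const int config_input_count = 2;  // corresponds to ios.ArraySize()
  input_index_map["BATCH_ELEMENT_COUNT"] = 0 + config_input_count;

  for (const auto& entry : input_index_map) {
    std::cout << entry.first << " -> " << entry.second << "\n";
  }
  return 0;
}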