diff --git a/build.py b/build.py index f35546882d..8fdf953d3a 100755 --- a/build.py +++ b/build.py @@ -1011,6 +1011,7 @@ def create_dockerfile_linux(ddir, dockerfile_name, argmap, backends, repoagents, if 'sagemaker' in endpoints: df += ''' LABEL com.amazonaws.sagemaker.capabilities.accept-bind-to-port=true +LABEL com.amazonaws.sagemaker.capabilities.multi-models=true COPY --chown=1000:1000 docker/sagemaker/serve /usr/bin/. ''' diff --git a/docker/sagemaker/serve b/docker/sagemaker/serve index c92e3bb407..e487f9af45 100755 --- a/docker/sagemaker/serve +++ b/docker/sagemaker/serve @@ -26,7 +26,20 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. SAGEMAKER_SINGLE_MODEL_REPO=/opt/ml/model/ -SAGEMAKER_ARGS="--model-repository=${SAGEMAKER_SINGLE_MODEL_REPO}" +SAGEMAKER_MULTI_MODEL_REPO=/opt/ml/models/ + +SAGEMAKER_MODEL_REPO=${SAGEMAKER_SINGLE_MODEL_REPO} +is_mme_mode=false + +if [ -n "$SAGEMAKER_MULTI_MODEL" ]; then + if [ "$SAGEMAKER_MULTI_MODEL" == "true" ]; then + SAGEMAKER_MODEL_REPO=${SAGEMAKER_MULTI_MODEL_REPO} + is_mme_mode=true + echo "Triton is running in SageMaker MME mode" + fi +fi + +SAGEMAKER_ARGS="--model-repository=${SAGEMAKER_MODEL_REPO}" if [ -n "$SAGEMAKER_BIND_TO_PORT" ]; then SAGEMAKER_ARGS="${SAGEMAKER_ARGS} --sagemaker-port=${SAGEMAKER_BIND_TO_PORT}" fi @@ -51,22 +64,28 @@ fi if [ -n "$SAGEMAKER_TRITON_LOG_ERROR" ]; then SAGEMAKER_ARGS="${SAGEMAKER_ARGS} --log-error=${SAGEMAKER_TRITON_LOG_ERROR}" fi +if [ -n "$SAGEMAKER_TRITON_SHM_DEFAULT_BYTE_SIZE" ]; then + SAGEMAKER_ARGS="${SAGEMAKER_ARGS} --backend-config=python,shm-default-byte-size=${SAGEMAKER_TRITON_SHM_DEFAULT_BYTE_SIZE}" +fi +if [ -n "$SAGEMAKER_TRITON_SHM_GROWTH_BYTE_SIZE" ]; then + SAGEMAKER_ARGS="${SAGEMAKER_ARGS} --backend-config=python,shm-growth-byte-size=${SAGEMAKER_TRITON_SHM_GROWTH_BYTE_SIZE}" +fi -if [ -f "${SAGEMAKER_SINGLE_MODEL_REPO}/config.pbtxt" ]; then +if [ "${is_mme_mode}" = false ] && [ -f "${SAGEMAKER_MODEL_REPO}/config.pbtxt" ]; then echo "ERROR: Incorrect directory structure." echo " Model directory needs to contain the top level folder" exit 1 fi -if [ -n "$SAGEMAKER_TRITON_DEFAULT_MODEL_NAME" ]; then - if [ -d "${SAGEMAKER_SINGLE_MODEL_REPO}/$SAGEMAKER_TRITON_DEFAULT_MODEL_NAME" ]; then +if [ "${is_mme_mode}" = false ] && [ -n "$SAGEMAKER_TRITON_DEFAULT_MODEL_NAME" ]; then + if [ -d "${SAGEMAKER_MODEL_REPO}/$SAGEMAKER_TRITON_DEFAULT_MODEL_NAME" ]; then SAGEMAKER_ARGS="${SAGEMAKER_ARGS} --load-model=${SAGEMAKER_TRITON_DEFAULT_MODEL_NAME}" else echo "ERROR: Directory with provided SAGEMAKER_TRITON_DEFAULT_MODEL_NAME ${SAGEMAKER_TRITON_DEFAULT_MODEL_NAME} does not exist" exit 1 fi -else - MODEL_DIRS=(`find "${SAGEMAKER_SINGLE_MODEL_REPO}" -mindepth 1 -maxdepth 1 -type d -printf "%f\n"`) +elif [ "${is_mme_mode}" = false ]; then + MODEL_DIRS=(`find "${SAGEMAKER_MODEL_REPO}" -mindepth 1 -maxdepth 1 -type d -printf "%f\n"`) case ${#MODEL_DIRS[@]} in 0) echo "ERROR: No model found in model repository"; exit 1 diff --git a/qa/L0_sagemaker/sagemaker_multi_model_test.py b/qa/L0_sagemaker/sagemaker_multi_model_test.py new file mode 100644 index 0000000000..820562c1da --- /dev/null +++ b/qa/L0_sagemaker/sagemaker_multi_model_test.py @@ -0,0 +1,222 @@ +#!/usr/bin/python +# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved. 
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#  * Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+#  * Neither the name of NVIDIA CORPORATION nor the names of its
+#    contributors may be used to endorse or promote products derived
+#    from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import sys
+
+sys.path.append("../common")
+
+import json
+import os
+import time
+import unittest
+
+import numpy as np
+import requests
+import test_util as tu
+import tritonclient.http as httpclient
+
+
+class SageMakerMultiModelTest(tu.TestResultCollector):
+    def setUp(self):
+
+        SAGEMAKER_BIND_TO_PORT = os.getenv("SAGEMAKER_BIND_TO_PORT", "8080")
+        self.url_mme_ = "http://localhost:{}/models".format(SAGEMAKER_BIND_TO_PORT)
+
+        # model_1 setup
+        self.model1_name = "sm_mme_model_1"
+        self.model1_url = "/opt/ml/models/123456789abcdefghi/model"
+
+        self.model1_input_data_ = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
+        self.model1_expected_output0_data_ = [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30]
+        self.model1_expected_output1_data_ = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
+
+        self.model1_expected_result_ = {
+            "model_name": "sm_mme_model_1",
+            "model_version": "1",
+            "outputs": [
+                {"name": "OUTPUT0", "datatype": "INT32", "shape": [1, 16], "data": self.model1_expected_output0_data_},
+                {"name": "OUTPUT1", "datatype": "INT32", "shape": [1, 16], "data": self.model1_expected_output1_data_},
+            ],
+        }
+
+        # model_2 setup
+        self.model2_name = "sm_mme_model_2"
+        self.model2_url = "/opt/ml/models/987654321ihgfedcba/model"
+
+        # Output is same as input since this is an identity model
+        self.model2_input_data_ = [0, 1, 2, 3, 4, 5, 6, 7]
+
+    def test_sm_0_environment_variables_set(self):
+        self.assertEqual(
+            os.getenv("SAGEMAKER_MULTI_MODEL"), "true", "Variable SAGEMAKER_MULTI_MODEL must be set to true"
+        )
+
+    def test_sm_1_model_load(self):
+        # Load model_1
+        request_body = {"model_name": self.model1_name, "url": self.model1_url}
+        headers = {"Content-Type": "application/json"}
+        r = requests.post(self.url_mme_, data=json.dumps(request_body), headers=headers)
+        time.sleep(5)  # wait for model to load
+        self.assertEqual(r.status_code, 200, "Expected status code 200, received {}".format(r.status_code))
+
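+        # Loading an already-mapped name makes repository registration fail
+        # with ALREADY_EXISTS, which the server surfaces as HTTP 409 (Conflict)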
+        # Load the same model again, expect a 409
+        request_body = {"model_name": self.model1_name, "url": self.model1_url}
+        headers = {"Content-Type": "application/json"}
+        r = requests.post(self.url_mme_, data=json.dumps(request_body), headers=headers)
+        time.sleep(5)  # wait for model to load
+        self.assertEqual(r.status_code, 409, "Expected status code 409, received {}".format(r.status_code))
+
+        # Load model_2
+        request_body = {"model_name": self.model2_name, "url": self.model2_url}
+        headers = {"Content-Type": "application/json"}
+        r = requests.post(self.url_mme_, data=json.dumps(request_body), headers=headers)
+        time.sleep(5)  # wait for model to load
+        self.assertEqual(r.status_code, 200, "Expected status code 200, received {}".format(r.status_code))
+
+    def test_sm_2_model_list(self):
+        r = requests.get(self.url_mme_)
+        time.sleep(3)
+        expected_response_1 = {
+            "models": [
+                {"modelName": self.model1_name, "modelUrl": self.model1_url},
+                {"modelName": self.model2_name, "modelUrl": self.model2_url},
+            ]
+        }
+        expected_response_2 = {
+            "models": [
+                {"modelName": self.model2_name, "modelUrl": self.model2_url},
+                {"modelName": self.model1_name, "modelUrl": self.model1_url},
+            ]
+        }
+
+        # The order of the returned model list is not deterministic
+        self.assertIn(
+            r.json(),
+            [expected_response_1, expected_response_2],
+            "Expected one of {}, received: {}".format([expected_response_1, expected_response_2], r.json()),
+        )
+
+    def test_sm_3_model_get(self):
+        get_url = "{}/{}".format(self.url_mme_, self.model1_name)
+        r = requests.get(get_url)
+        time.sleep(3)
+        expected_response = {"modelName": self.model1_name, "modelUrl": self.model1_url}
+        self.assertEqual(
+            r.json(), expected_response, "Expected response: {}, received: {}".format(expected_response, r.json())
+        )
+
+    def test_sm_4_model_invoke(self):
+        # Invoke model_1
+        inputs = []
+        outputs = []
+        inputs.append(httpclient.InferInput("INPUT0", [1, 16], "INT32"))
+        inputs.append(httpclient.InferInput("INPUT1", [1, 16], "INT32"))
+
+        # Initialize the data
+        input_data = np.array(self.model1_input_data_, dtype=np.int32)
+        input_data = np.expand_dims(input_data, axis=0)
+        inputs[0].set_data_from_numpy(input_data, binary_data=False)
+        inputs[1].set_data_from_numpy(input_data, binary_data=False)
+
+        outputs.append(httpclient.InferRequestedOutput("OUTPUT0", binary_data=False))
+        outputs.append(httpclient.InferRequestedOutput("OUTPUT1", binary_data=False))
+        request_body, _ = httpclient.InferenceServerClient.generate_request_body(inputs, outputs=outputs)
+
+        headers = {"Content-Type": "application/json"}
+        invoke_url = "{}/{}/invoke".format(self.url_mme_, self.model1_name)
+        r = requests.post(invoke_url, data=request_body, headers=headers)
+        r.raise_for_status()
+
+        self.assertEqual(
+            self.model1_expected_result_,
+            r.json(),
+            "Expected response: {}, received: {}".format(self.model1_expected_result_, r.json()),
+        )
+
+        # Invoke model_2
+        inputs = []
+        outputs = []
+        inputs.append(httpclient.InferInput("INPUT0", [1, 8], "FP32"))
+        input_data = np.array(self.model2_input_data_, dtype=np.float32)
+        input_data = np.expand_dims(input_data, axis=0)
+        inputs[0].set_data_from_numpy(input_data, binary_data=True)
+
+        outputs.append(httpclient.InferRequestedOutput("OUTPUT0", binary_data=True))
+
+        request_body, header_length = httpclient.InferenceServerClient.generate_request_body(inputs, outputs=outputs)
+
+        invoke_url = "{}/{}/invoke".format(self.url_mme_, self.model2_name)
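+        # For binary tensor data the request body is a JSON inference header
+        # followed by the raw tensor bytes; the json-header-size parameter in
+        # the Content-Type tells the server where the JSON header ends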
"application/vnd.sagemaker-triton.binary+json;json-header-size={}".format(header_length) + } + r = requests.post(invoke_url, data=request_body, headers=headers) + + header_length_prefix = "application/vnd.sagemaker-triton.binary+json;json-header-size=" + header_length_str = r.headers["Content-Type"][len(header_length_prefix) :] + result = httpclient.InferenceServerClient.parse_response_body(r._content, header_length=int(header_length_str)) + + # Get the inference header size so we can locate the output binary data + output_data = result.as_numpy("OUTPUT0") + + for i in range(8): + self.assertEqual(output_data[0][i], input_data[0][i], "Tensor Value Mismatch") + + def test_sm_5_model_unload(self): + # Unload model_1 + unload_url = "{}/{}".format(self.url_mme_, self.model1_name) + r = requests.delete(unload_url) + time.sleep(3) + self.assertEqual(r.status_code, 200, "Expected status code 200, received {}".format(r.status_code)) + + # Unload model_2 + unload_url = "{}/{}".format(self.url_mme_, self.model2_name) + r = requests.delete(unload_url) + time.sleep(3) + self.assertEqual(r.status_code, 200, "Expected status code 200, received {}".format(r.status_code)) + + # Unload a non-loaded model, expect a 404 + unload_url = "{}/sm_non_loaded_model".format(self.url_mme_) + r = requests.delete(unload_url) + time.sleep(3) + self.assertEqual(r.status_code, 404, "Expected status code 404, received {}".format(r.status_code)) + + +if __name__ == "__main__": + unittest.main() diff --git a/qa/L0_sagemaker/test.sh b/qa/L0_sagemaker/test.sh old mode 100644 new mode 100755 index 4c768b1e5e..e701e8dd71 --- a/qa/L0_sagemaker/test.sh +++ b/qa/L0_sagemaker/test.sh @@ -1,5 +1,5 @@ #!/bin/bash -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions @@ -55,6 +55,8 @@ rm -f *.log rm -f *.out SAGEMAKER_TEST=sagemaker_test.py +SAGEMAKER_MULTI_MODEL_TEST=sagemaker_multi_model_test.py +MULTI_MODEL_UNIT_TEST_COUNT=6 UNIT_TEST_COUNT=9 CLIENT_LOG="./client.log" @@ -363,6 +365,68 @@ fi kill $SERVER_PID wait $SERVER_PID +# MME begin +# Prepare model repository + +ln -s `pwd`/models /opt/ml/models +# Model path will be of the form /opt/ml/models//model +MODEL1_PATH="models/123456789abcdefghi/model" +MODEL2_PATH="models/987654321ihgfedcba/model" +mkdir -p "${MODEL1_PATH}" +mkdir -p "${MODEL2_PATH}" + +cp -r $DATADIR/qa_model_repository/onnx_int32_int32_int32/* ${MODEL1_PATH} && \ + rm -r ${MODEL1_PATH}/2 && rm -r ${MODEL1_PATH}/3 && \ + sed -i "s/onnx_int32_int32_int32/sm_mme_model_1/" ${MODEL1_PATH}/config.pbtxt + +cp -r $DATADIR/qa_identity_model_repository/onnx_zero_1_float32/* ${MODEL2_PATH} && \ + sed -i "s/onnx_zero_1_float32/sm_mme_model_2/" ${MODEL2_PATH}/config.pbtxt + +# Start server with 'serve' script +export SAGEMAKER_MULTI_MODEL=true +export SAGEMAKER_TRITON_LOG_VERBOSE=true + +serve > $SERVER_LOG 2>&1 & +SERVE_PID=$! +# Obtain Triton PID in such way as $! will return the script PID +sleep 1 +SERVER_PID=`ps | grep tritonserver | awk '{ printf $1 }'` +sagemaker_wait_for_server_ready $SERVER_PID 10 +if [ "$WAIT_RET" != "0" ]; then + echo -e "\n***\n*** Failed to start $SERVER\n***" + kill $SERVER_PID || true + cat $SERVER_LOG + exit 1 +fi + +# API tests in default setting +set +e +python $SAGEMAKER_MULTI_MODEL_TEST SageMakerMultiModelTest >>$CLIENT_LOG 2>&1 +if [ $? 
+if [ $? -ne 0 ]; then
+    echo -e "\n***\n*** Test Failed\n***"
+    cat $CLIENT_LOG
+    RET=1
+else
+    check_test_results $TEST_RESULT_FILE $MULTI_MODEL_UNIT_TEST_COUNT
+    if [ $? -ne 0 ]; then
+        cat $CLIENT_LOG
+        echo -e "\n***\n*** Test Result Verification Failed\n***"
+        RET=1
+    fi
+fi
+set -e
+
+unset SAGEMAKER_MULTI_MODEL
+
+unlink /opt/ml/models
+rm -rf /opt/ml/models
+
+kill $SERVER_PID
+wait $SERVE_PID
+
+# MME end
+
 
 unlink /opt/ml/model
 rm -rf /opt/ml/model
diff --git a/src/sagemaker_server.cc b/src/sagemaker_server.cc
index 1b0d68d736..bc1f7dffd7 100644
--- a/src/sagemaker_server.cc
+++ b/src/sagemaker_server.cc
@@ -27,6 +27,93 @@
 namespace triton { namespace server {
 
+#define HTTP_RESPOND_IF_ERR(REQ, X)                   \
+  do {                                                \
+    TRITONSERVER_Error* err__ = (X);                  \
+    if (err__ != nullptr) {                           \
+      EVBufferAddErrorJson((REQ)->buffer_out, err__); \
+      evhtp_send_reply((REQ), EVHTP_RES_BADREQ);      \
+      TRITONSERVER_ErrorDelete(err__);                \
+      return;                                         \
+    }                                                 \
+  } while (false)
+
+namespace {
+
+void
+EVBufferAddErrorJson(evbuffer* buffer, TRITONSERVER_Error* err)
+{
+  const char* message = TRITONSERVER_ErrorMessage(err);
+
+  triton::common::TritonJson::Value response(
+      triton::common::TritonJson::ValueType::OBJECT);
+  response.AddStringRef("error", message, strlen(message));
+
+  triton::common::TritonJson::WriteBuffer buffer_json;
+  response.Write(&buffer_json);
+
+  evbuffer_add(buffer, buffer_json.Base(), buffer_json.Size());
+}
+
+TRITONSERVER_Error*
+EVBufferToJson(
+    triton::common::TritonJson::Value* document, evbuffer_iovec* v, int* v_idx,
+    const size_t length, int n)
+{
+  size_t offset = 0, remaining_length = length;
+  char* json_base;
+  std::vector<char> json_buffer;
+
+  // No need to memcpy when number of iovecs is 1
+  if ((n > 0) && (v[0].iov_len >= remaining_length)) {
+    json_base = static_cast<char*>(v[0].iov_base);
+    if (v[0].iov_len > remaining_length) {
+      v[0].iov_base = static_cast<void*>(json_base + remaining_length);
+      v[0].iov_len -= remaining_length;
+      remaining_length = 0;
+    } else if (v[0].iov_len == remaining_length) {
+      remaining_length = 0;
+      *v_idx += 1;
+    }
+  } else {
+    json_buffer.resize(length);
+    json_base = json_buffer.data();
+    while ((remaining_length > 0) && (*v_idx < n)) {
+      char* base = static_cast<char*>(v[*v_idx].iov_base);
+      size_t base_size;
+      if (v[*v_idx].iov_len > remaining_length) {
+        base_size = remaining_length;
+        v[*v_idx].iov_base = static_cast<void*>(base + remaining_length);
+        v[*v_idx].iov_len -= remaining_length;
+        remaining_length = 0;
+      } else {
+        base_size = v[*v_idx].iov_len;
+        remaining_length -= v[*v_idx].iov_len;
+        *v_idx += 1;
+      }
+
+      memcpy(json_base + offset, base, base_size);
+      offset += base_size;
+    }
+  }
+
+  if (remaining_length != 0) {
+    return TRITONSERVER_ErrorNew(
+        TRITONSERVER_ERROR_INVALID_ARG,
+        std::string(
+            "unexpected size for request JSON, expecting " +
+            std::to_string(remaining_length) + " more bytes")
+            .c_str());
+  }
+
+  RETURN_IF_ERR(document->Parse(json_base, length));
+
+  return nullptr;  // success
+}
+
+}  // namespace
+
+
 const std::string SagemakerAPIServer::binary_mime_type_(
     "application/vnd.sagemaker-triton.binary+json;json-header-size=");
 
@@ -113,6 +200,63 @@ SagemakerAPIServer::Handle(evhtp_request_t* req)
     return;
   }
 
+  std::string multi_model_name, action;
+  if (RE2::FullMatch(
+          std::string(req->uri->path->full), models_regex_, &multi_model_name,
+          &action)) {
+    switch (req->method) {
+      case htp_method_GET:
+        if (multi_model_name.empty()) {
+          LOG_VERBOSE(1) << "SageMaker request: LIST ALL MODELS";
+
+          SageMakerMMEListModel(req);
+          return;
+        } else {
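+          // GET /models/<model_name>: return the {modelName, modelUrl}
+          // pair recorded at load time, or 404 if the model is not loaded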
+          LOG_VERBOSE(1) << "SageMaker request: GET MODEL";
+
+          SageMakerMMEGetModel(req, multi_model_name.c_str());
+          return;
+        }
+      case htp_method_POST:
+        if (action == "/invoke") {
+          LOG_VERBOSE(1) << "SageMaker request: INVOKE MODEL";
+
+          if (sagemaker_models_list_.find(multi_model_name.c_str()) ==
+              sagemaker_models_list_.end()) {
+            evhtp_send_reply(req, EVHTP_RES_NOTFOUND); /* 404 */
+            return;
+          }
+
+          HandleInfer(req, multi_model_name, model_version_str_);
+          return;
+        }
+        if (action.empty()) {
+          LOG_VERBOSE(1) << "SageMaker request: LOAD MODEL";
+
+          std::unordered_map<std::string, std::string> parse_load_map;
+          ParseSageMakerRequest(req, &parse_load_map, "load");
+          SageMakerMMELoadModel(req, parse_load_map);
+          return;
+        }
+        break;
+      case htp_method_DELETE: {
+        // UNLOAD MODEL
+        LOG_VERBOSE(1) << "SageMaker request: UNLOAD MODEL";
+        // Rewrite the method to POST so the inherited repository-control
+        // handler, which only accepts POST, will process the unload
+        req->method = htp_method_POST;
+
+        SageMakerMMEUnloadModel(req, multi_model_name.c_str());
+
+        return;
+      }
+      default:
+        LOG_VERBOSE(1) << "SageMaker error: " << req->method << " "
+                       << req->uri->path->full << " - "
+                       << static_cast<int>(EVHTP_RES_BADREQ);
+        evhtp_send_reply(req, EVHTP_RES_BADREQ);
+        return;
+    }
+  }
+
   LOG_VERBOSE(1) << "SageMaker error: " << req->method << " "
                  << req->uri->path->full << " - "
                  << static_cast<int>(EVHTP_RES_BADREQ);
@@ -138,4 +282,283 @@ SagemakerAPIServer::Create(
   return nullptr;
 }
 
+
+void
+SagemakerAPIServer::ParseSageMakerRequest(
+    evhtp_request_t* req,
+    std::unordered_map<std::string, std::string>* parse_map,
+    const std::string& action)
+{
+  struct evbuffer_iovec* v = nullptr;
+  int v_idx = 0;
+  int n = evbuffer_peek(req->buffer_in, -1, NULL, NULL, 0);
+  if (n > 0) {
+    v = static_cast<struct evbuffer_iovec*>(
+        alloca(sizeof(struct evbuffer_iovec) * n));
+    if (evbuffer_peek(req->buffer_in, -1, NULL, v, n) != n) {
+      HTTP_RESPOND_IF_ERR(
+          req, TRITONSERVER_ErrorNew(
+                   TRITONSERVER_ERROR_INTERNAL,
+                   "unexpected error getting load model request buffers"));
+    }
+  }
+
+  std::string model_name_string;
+  std::string url_string;
+
+  size_t buffer_len = evbuffer_get_length(req->buffer_in);
+  if (buffer_len > 0) {
+    triton::common::TritonJson::Value request;
+    HTTP_RESPOND_IF_ERR(
+        req, EVBufferToJson(&request, v, &v_idx, buffer_len, n));
+
+    triton::common::TritonJson::Value url;
+    triton::common::TritonJson::Value model_name;
+
+    if (request.Find("model_name", &model_name)) {
+      HTTP_RESPOND_IF_ERR(req, model_name.AsString(&model_name_string));
+      LOG_VERBOSE(1) << "Received model_name: " << model_name_string.c_str();
+    }
+
+    if ((action == "load") && (request.Find("url", &url))) {
+      HTTP_RESPOND_IF_ERR(req, url.AsString(&url_string));
+      LOG_VERBOSE(1) << "Received url: " << url_string.c_str();
+    }
+  }
+
+  if (action == "load") {
+    (*parse_map)["url"] = url_string.c_str();
+  }
+  (*parse_map)["model_name"] = model_name_string.c_str();
+
+  return;
+}
+
+void
+SagemakerAPIServer::SageMakerMMEUnloadModel(
+    evhtp_request_t* req, const char* model_name)
+{
+  std::lock_guard<std::mutex> lock(mutex_);
+
+  if (sagemaker_models_list_.find(model_name) == sagemaker_models_list_.end()) {
+    LOG_VERBOSE(1) << "Model " << model_name << " is not loaded."
<< std::endl; + evhtp_send_reply(req, EVHTP_RES_NOTFOUND); /* 404*/ + return; + } + + HandleRepositoryControl(req, "", model_name, "unload"); + + std::string repo_path = sagemaker_models_list_.at(model_name); + + std::string repo_parent_path, subdir, customer_subdir; + RE2::FullMatch( + repo_path, model_path_regex_, &repo_parent_path, &subdir, + &customer_subdir); + + TRITONSERVER_Error* unload_err = TRITONSERVER_ServerUnregisterModelRepository( + server_.get(), repo_parent_path.c_str()); + + if (unload_err != nullptr) { + EVBufferAddErrorJson(req->buffer_out, unload_err); + evhtp_send_reply(req, EVHTP_RES_BADREQ); + TRITONSERVER_ErrorDelete(unload_err); + } + + sagemaker_models_list_.erase(model_name); +} + +void +SagemakerAPIServer::SageMakerMMEGetModel( + evhtp_request_t* req, const char* model_name) +{ + if (sagemaker_models_list_.find(model_name) == sagemaker_models_list_.end()) { + evhtp_send_reply(req, EVHTP_RES_NOTFOUND); /* 404*/ + return; + } + + triton::common::TritonJson::Value sagemaker_get_json( + triton::common::TritonJson::ValueType::OBJECT); + + sagemaker_get_json.AddString("modelName", model_name); + sagemaker_get_json.AddString( + "modelUrl", sagemaker_models_list_.at(model_name)); + + const char* buffer; + size_t byte_size; + + triton::common::TritonJson::WriteBuffer json_buffer_; + json_buffer_.Clear(); + sagemaker_get_json.Write(&json_buffer_); + + byte_size = json_buffer_.Size(); + buffer = json_buffer_.Base(); + + evbuffer_add(req->buffer_out, buffer, byte_size); + evhtp_send_reply(req, EVHTP_RES_OK); +} + +void +SagemakerAPIServer::SageMakerMMEListModel(evhtp_request_t* req) +{ + triton::common::TritonJson::Value sagemaker_list_json( + triton::common::TritonJson::ValueType::OBJECT); + + triton::common::TritonJson::Value models_array( + sagemaker_list_json, triton::common::TritonJson::ValueType::ARRAY); + + for (auto it = sagemaker_models_list_.begin(); + it != sagemaker_models_list_.end(); it++) { + triton::common::TritonJson::Value model_url_pair( + models_array, triton::common::TritonJson::ValueType::OBJECT); + + bool ready = false; + TRITONSERVER_ServerModelIsReady( + server_.get(), it->first.c_str(), 1, &ready); + + /* Add to return list only if model is ready to be served */ + if (ready) { + model_url_pair.AddString("modelName", it->first); + model_url_pair.AddString("modelUrl", it->second); + } + + models_array.Append(std::move(model_url_pair)); + } + + sagemaker_list_json.Add("models", std::move(models_array)); + + const char* buffer; + size_t byte_size; + + triton::common::TritonJson::WriteBuffer json_buffer_; + json_buffer_.Clear(); + sagemaker_list_json.Write(&json_buffer_); + + byte_size = json_buffer_.Size(); + buffer = json_buffer_.Base(); + + evbuffer_add(req->buffer_out, buffer, byte_size); + evhtp_send_reply(req, EVHTP_RES_OK); +} + +void +SagemakerAPIServer::SageMakerMMELoadModel( + evhtp_request_t* req, + const std::unordered_map parse_map) +{ + std::string repo_path = parse_map.at("url"); + std::string model_name = parse_map.at("model_name"); + + /* Error out if there's more than one subdir/version within + * supplied model repo, as ensemble in MME is not (currently) + * supported + */ + DIR* dir; + struct dirent* ent; + int dir_count = 0; + if ((dir = opendir(repo_path.c_str())) != NULL) { + while ((ent = readdir(dir)) != NULL) { + if ((ent->d_type == DT_DIR) && (strcmp(ent->d_name, ".") == 0) && + (strcmp(ent->d_name, "..") == 0)) { + dir_count += 1; + } + if (dir_count > 1) { + HTTP_RESPOND_IF_ERR( + req, + TRITONSERVER_ErrorNew( + 
+  std::string repo_parent_path, subdir, customer_subdir;
+  RE2::FullMatch(
+      repo_path, model_path_regex_, &repo_parent_path, &subdir,
+      &customer_subdir);
+
+  auto param = TRITONSERVER_ParameterNew(
+      subdir.c_str(), TRITONSERVER_PARAMETER_STRING, model_name.c_str());
+
+  if (param != nullptr) {
+    subdir_modelname_map.emplace_back(param);
+  } else {
+    HTTP_RESPOND_IF_ERR(
+        req, TRITONSERVER_ErrorNew(
+                 TRITONSERVER_ERROR_INTERNAL,
+                 "unexpected error on creating Triton parameter"));
+  }
+
+  /* Register repository with model mapping */
+  TRITONSERVER_Error* err = nullptr;
+  err = TRITONSERVER_ServerRegisterModelRepository(
+      server_.get(), repo_parent_path.c_str(), subdir_modelname_map.data(),
+      subdir_modelname_map.size());
+
+  TRITONSERVER_ParameterDelete(param);
+
+  // If a model_name is reused, i.e. model_name is already mapped, return a 409
+  if ((err != nullptr) &&
+      (TRITONSERVER_ErrorCode(err) == TRITONSERVER_ERROR_ALREADY_EXISTS)) {
+    EVBufferAddErrorJson(req->buffer_out, err);
+    evhtp_send_reply(req, EVHTP_RES_CONFLICT); /* 409 */
+    TRITONSERVER_ErrorDelete(err);
+    return;
+  } else if (err != nullptr) {
+    EVBufferAddErrorJson(req->buffer_out, err);
+    evhtp_send_reply(req, EVHTP_RES_BADREQ);
+    TRITONSERVER_ErrorDelete(err);
+    return;
+  }
+
+  err = TRITONSERVER_ServerLoadModel(server_.get(), model_name.c_str());
+
+  /* Unlikely after the duplicate-repo check above, but handle the case where
+   * Load Model also returns an ALREADY_EXISTS error */
+  if ((err != nullptr) &&
+      (TRITONSERVER_ErrorCode(err) == TRITONSERVER_ERROR_ALREADY_EXISTS)) {
+    EVBufferAddErrorJson(req->buffer_out, err);
+    evhtp_send_reply(req, EVHTP_RES_CONFLICT); /* 409 */
+    TRITONSERVER_ErrorDelete(err);
+    return;
+  } else if (err != nullptr) {
+    EVBufferAddErrorJson(req->buffer_out, err);
+    evhtp_send_reply(req, EVHTP_RES_BADREQ);
+  } else {
+    std::lock_guard<std::mutex> lock(mutex_);
+
+    sagemaker_models_list_.emplace(model_name, repo_path);
+    evhtp_send_reply(req, EVHTP_RES_OK);
+  }
+
+  /* Unregister the model repository in case of load failure */
+  if (err != nullptr) {
+    err = TRITONSERVER_ServerUnregisterModelRepository(
+        server_.get(), repo_parent_path.c_str());
+    LOG_VERBOSE(1)
+        << "Unregistered model repository due to load failure for model: "
+        << model_name << std::endl;
+  }
+
+  if (err != nullptr) {
+    EVBufferAddErrorJson(req->buffer_out, err);
+    evhtp_send_reply(req, EVHTP_RES_BADREQ);
+    TRITONSERVER_ErrorDelete(err);
+  }
+
+  return;
+}
+
 }}  // namespace triton::server
diff --git a/src/sagemaker_server.h b/src/sagemaker_server.h
index fbb472dd1f..a20ab239a4 100644
--- a/src/sagemaker_server.h
+++ b/src/sagemaker_server.h
@@ -25,9 +25,12 @@
 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 #pragma once
 
-#include "http_server.h"
+#include <mutex>
 
 #include "common.h"
+#include "dirent.h"
+#include "http_server.h"
+#include "triton/core/tritonserver.h"
 
 namespace triton { namespace server {
 
@@ -63,6 +66,9 @@ class SagemakerAPIServer : public HTTPAPIServer {
       : HTTPAPIServer(
             server, trace_manager, shm_manager, port, address, thread_cnt),
         ping_regex_(R"(/ping)"), invocations_regex_(R"(/invocations)"),
+        models_regex_(R"(/models(?:/)?([^/]+)?(/invoke)?)"),
+        model_path_regex_(
+            R"((\/opt\/ml\/models\/[0-9A-Za-z._]+)\/(model)\/?([0-9A-Za-z._]+)?)"),
         ping_mode_("ready"),
         model_name_(GetEnvironmentVariableOrDefault(
             "SAGEMAKER_TRITON_DEFAULT_MODEL_NAME",
@@ -71,6 +77,21 @@ class SagemakerAPIServer : public HTTPAPIServer {
   {
   }
 
+  void ParseSageMakerRequest(
+      evhtp_request_t* req,
+      std::unordered_map<std::string, std::string>* parse_map,
+      const std::string& action);
+
+  void SageMakerMMELoadModel(
+      evhtp_request_t* req,
+      const std::unordered_map<std::string, std::string> parse_map);
+
+  void SageMakerMMEUnloadModel(evhtp_request_t* req, const char* model_name);
+
+  void SageMakerMMEListModel(evhtp_request_t* req);
+
+  void SageMakerMMEGetModel(evhtp_request_t* req, const char* model_name);
+
   void Handle(evhtp_request_t* req) override;
 
   std::unique_ptr<InferRequestClass> CreateInferRequest(
       evhtp_request_t* req,
@@ -83,6 +104,7 @@ class SagemakerAPIServer : public HTTPAPIServer {
       evhtp_request_t* req, int32_t content_length,
       size_t* header_length) override;
+
   // Currently the compression schema hasn't been defined,
   // assume identity compression type is used for both request and response
   DataCompressor::Type GetRequestCompressionType(evhtp_request_t* req) override
@@ -95,14 +117,23 @@ class SagemakerAPIServer : public HTTPAPIServer {
   }
   re2::RE2 ping_regex_;
   re2::RE2 invocations_regex_;
+  re2::RE2 models_regex_;
+  re2::RE2 model_path_regex_;
 
   const std::string ping_mode_;
 
-  // For single model mode, assume that only one version of "model" is presented
+  /* For single model mode, assume that only one version of "model" is present
+   */
   const std::string model_name_;
   const std::string model_version_str_;
 
   static const std::string binary_mime_type_;
+
+  /* Maintain list of loaded models */
+  std::unordered_map<std::string, std::string> sagemaker_models_list_;
+
+  /* Mutex to handle concurrent updates */
+  std::mutex mutex_;
 };
 
 }}  // namespace triton::server
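For reference, the MME endpoints added above can be exercised with plain HTTP once a container is started with SAGEMAKER_MULTI_MODEL=true. A minimal sketch using the requests library; the port, model name, and repository path are illustrative values taken from the test above, not fixed by the server:

    import json
    import requests

    base = "http://localhost:8080/models"  # assumes SAGEMAKER_BIND_TO_PORT=8080

    # Load: POST /models with a JSON body naming the model and its repository path
    body = {"model_name": "sm_mme_model_1",
            "url": "/opt/ml/models/123456789abcdefghi/model"}
    r = requests.post(base, data=json.dumps(body),
                      headers={"Content-Type": "application/json"})
    assert r.status_code == 200  # 409 if the model name is already mapped

    # List all loaded models, then fetch one of them
    print(requests.get(base).json())
    print(requests.get(base + "/sm_mme_model_1").json())

    # Invoke: POST /models/<name>/invoke with a standard Triton JSON inference request
    infer = {"inputs": [
        {"name": "INPUT0", "datatype": "INT32", "shape": [1, 16], "data": list(range(16))},
        {"name": "INPUT1", "datatype": "INT32", "shape": [1, 16], "data": list(range(16))},
    ]}
    r = requests.post(base + "/sm_mme_model_1/invoke", data=json.dumps(infer),
                      headers={"Content-Type": "application/json"})
    print(r.json())

    # Unload: DELETE /models/<name>; a name that is not loaded returns 404
    requests.delete(base + "/sm_mme_model_1")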