Add new REST API to get status of models known to the ModelServer.
The API is accessible via GET method on following URL:

  http://host:port/v1/models/${MODEL_NAME}[/versions/${MODEL_VERSION}]

This closely mirrors the `ModelService.GetModelStatus` gRPC API.
If no ${MODEL_VERSION} is specified, status for all versions of
the model is returned.

PiperOrigin-RevId: 210554222
netfs authored and tensorflower-gardener committed Aug 28, 2018
1 parent c0c6648 commit 00e459f
Showing 8 changed files with 221 additions and 65 deletions.
3 changes: 2 additions & 1 deletion tensorflow_serving/apis/get_model_status.proto
@@ -63,5 +63,6 @@ message ModelVersionStatus {
// Response for ModelStatusRequest on successful run.
message GetModelStatusResponse {
// Version number and status information for applicable model version(s).
repeated ModelVersionStatus model_version_status = 1;
repeated ModelVersionStatus model_version_status = 1
[json_name = "model_version_status"];
}
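The `[json_name = "model_version_status"]` option pins the JSON field name to its snake_case proto name; without it, protobuf's JSON printer defaults to lowerCamelCase. A minimal sketch of that default mapping (the helper below is purely illustrative, not part of the codebase):

```python
def default_json_name(field_name):
    """Illustrates protobuf's default proto-field-to-JSON name mapping:
    underscores are dropped and the letter after each underscore is
    capitalized (lowerCamelCase)."""
    parts = field_name.split("_")
    return parts[0] + "".join(p.capitalize() for p in parts[1:])

# Without the json_name option, the field would serialize as:
print(default_json_name("model_version_status"))  # modelVersionStatus
```

Pinning the name keeps the REST response stable even if the JSON printer's defaults change.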
100 changes: 70 additions & 30 deletions tensorflow_serving/g3doc/api_rest.md
@@ -1,33 +1,9 @@
# RESTful API

In addition to [gRPC
APIs](https://github.com/tensorflow/serving/blob/master/tensorflow_serving/apis/prediction_service.proto)
TensorFlow ModelServer also supports RESTful APIs for classification, regression
and prediction on TensorFlow models. This page describes these API endpoints and
format of request/response involved in using them.

TensorFlow ModelServer running on `host:port` accepts following REST API
requests:

```
POST http://host:port/<URI>:<VERB>
URI: /v1/models/${MODEL_NAME}[/versions/${MODEL_VERSION}]
VERB: classify|regress|predict
```

`/versions/${MODEL_VERSION}` is optional. If omitted the latest version is used.

This API closely follows the gRPC version of
[`PredictionService`](https://github.com/tensorflow/serving/blob/5369880e9143aa00d586ee536c12b04e945a977c/tensorflow_serving/apis/prediction_service.proto#L15)
API.

Examples of request URLs:

```
http://host:port/v1/models/iris:classify
http://host:port/v1/models/mnist/versions/314:predict
```
In addition to
[gRPC APIs](https://github.com/tensorflow/serving/blob/master/tensorflow_serving/apis/prediction_service.proto)
TensorFlow ModelServer also supports RESTful APIs. This page describes these API
endpoints and an end-to-end [example](#example) on usage.

The request and response are JSON objects. The composition of these objects
depends on the request type or verb. See the API-specific sections below for
@@ -42,8 +18,41 @@ In case of error, all APIs will return a JSON object in the response body with
}
```

## Model status API

This API closely follows the
[`ModelService.GetModelStatus`](https://github.com/tensorflow/serving/blob/5369880e9143aa00d586ee536c12b04e945a977c/tensorflow_serving/apis/model_service.proto#L17)
gRPC API. It returns the status of a model in the ModelServer.

### URL

```
GET http://host:port/v1/models/${MODEL_NAME}[/versions/${MODEL_VERSION}]
```

`/versions/${MODEL_VERSION}` is optional. If omitted, status for all versions is
returned in the response.

### Response format

If successful, returns a JSON representation of
[`GetModelStatusResponse`](https://github.com/tensorflow/serving/blob/5369880e9143aa00d586ee536c12b04e945a977c/tensorflow_serving/apis/get_model_status.proto#L64)
protobuf.
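Assuming a ModelServer listening on `localhost:8501` and a model named `half_plus_three` (both names are illustrative), a status query could be sketched in Python as follows; the commented lines show how the JSON response might be consumed:

```python
import json
from urllib.request import urlopen

def model_status_url(host, model_name, version=None):
    # Build the status URL; the /versions/<n> segment is optional.
    url = f"http://{host}/v1/models/{model_name}"
    if version is not None:
        url += f"/versions/{version}"
    return url

# With a live server (not run here):
# resp = json.load(urlopen(model_status_url("localhost:8501", "half_plus_three")))
# for v in resp["model_version_status"]:
#     print(v["version"], v["state"])

print(model_status_url("localhost:8501", "half_plus_three", 123))
# http://localhost:8501/v1/models/half_plus_three/versions/123
```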

## Classify and Regress API

This API closely follows the `Classify` and `Regress` methods of the
[`PredictionService`](https://github.com/tensorflow/serving/blob/5369880e9143aa00d586ee536c12b04e945a977c/tensorflow_serving/apis/prediction_service.proto#L15)
gRPC API.

### URL

```
POST http://host:port/v1/models/${MODEL_NAME}[/versions/${MODEL_VERSION}]:(classify|regress)
```

`/versions/${MODEL_VERSION}` is optional. If omitted, the latest version is used.
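As a quick sketch, a classify URL and body can be assembled as below. The model name `iris` and the feature key are illustrative, and the `"examples"` list is an assumption about the request body shape; the authoritative format is in the request-format section of this page:

```python
import json

def classify_request(host, model_name, examples, version=None):
    """Build the URL and JSON body for a classify call (sketch only)."""
    url = f"http://{host}/v1/models/{model_name}"
    if version is not None:
        url += f"/versions/{version}"
    url += ":classify"
    # Assumed body shape: a list of example objects under "examples".
    body = json.dumps({"examples": examples})
    return url, body

url, body = classify_request("localhost:8501", "iris", [{"sepal_length": 5.1}])
print(url)  # http://localhost:8501/v1/models/iris:classify
```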

### Request format

The request body for the `classify` and `regress` APIs must be a JSON object
@@ -128,6 +137,18 @@ Users of gRPC API will notice the similarity of this format with

## Predict API

This API closely follows the
[`PredictionService.Predict`](https://github.com/tensorflow/serving/blob/5369880e9143aa00d586ee536c12b04e945a977c/tensorflow_serving/apis/prediction_service.proto#L23)
gRPC API.

### URL

```
POST http://host:port/v1/models/${MODEL_NAME}[/versions/${MODEL_VERSION}]:predict
```

`/versions/${MODEL_VERSION}` is optional. If omitted, the latest version is used.
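A small sketch of assembling a predict call (host and model name are illustrative; the `"instances"` key matches the curl example later in this page):

```python
import json
from urllib.request import Request

def predict_request(host, model_name, instances, version=None):
    """Build an HTTP POST request object for the predict API (sketch)."""
    url = f"http://{host}/v1/models/{model_name}"
    if version is not None:
        url += f"/versions/{version}"
    url += ":predict"
    return Request(url,
                   data=json.dumps({"instances": instances}).encode(),
                   headers={"Content-Type": "application/json"},
                   method="POST")

req = predict_request("localhost:8501", "half_plus_three", [1.0, 2.0, 5.0])
print(req.full_url)  # http://localhost:8501/v1/models/half_plus_three:predict
```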

### Request format

The request body for the `predict` API must be a JSON object formatted as follows:
@@ -377,8 +398,27 @@ $ tensorflow_model_server --rest_api_port=8501 \

### Make REST API calls to ModelServer

In a different terminal, use the `curl` tool to make REST API calls. A `predict`
call would look as follows:
In a different terminal, use the `curl` tool to make REST API calls.

Get the status of the model as follows:

```
$ curl http://localhost:8501/v1/models/half_plus_three
{
"model_version_status": [
{
"version": "123",
"state": "AVAILABLE",
"status": {
"error_code": "OK",
"error_message": ""
}
}
]
}
```
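The status JSON above can also be checked programmatically. A small sketch that extracts the available versions from such a response:

```python
import json

# The exact response shown above, embedded as a string for illustration.
status_json = """
{
 "model_version_status": [
  {
   "version": "123",
   "state": "AVAILABLE",
   "status": { "error_code": "OK", "error_message": "" }
  }
 ]
}
"""

def available_versions(payload):
    # Collect versions whose state is AVAILABLE. Versions are JSON
    # strings because the proto field is a 64-bit integer.
    return [v["version"] for v in payload["model_version_status"]
            if v["state"] == "AVAILABLE"]

print(available_versions(json.loads(status_json)))  # ['123']
```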

A `predict` call would look as follows:

```shell
$ curl -d '{"instances": [1.0,2.0,5.0]}' -X POST http://localhost:8501/v1/models/half_plus_three:predict
1 change: 1 addition & 0 deletions tensorflow_serving/model_servers/BUILD
@@ -259,6 +259,7 @@ cc_library(
hdrs = ["http_rest_api_handler.h"],
visibility = ["//visibility:public"],
deps = [
":get_model_status_impl",
":server_core",
"//tensorflow_serving/apis:model_proto",
"//tensorflow_serving/apis:predict_proto",
49 changes: 47 additions & 2 deletions tensorflow_serving/model_servers/http_rest_api_handler.cc
@@ -17,16 +17,19 @@ limitations under the License.

#include <string>

#include "google/protobuf/util/json_util.h"
#include "absl/strings/escaping.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_replace.h"
#include "absl/time/time.h"
#include "tensorflow/cc/saved_model/loader.h"
#include "tensorflow/cc/saved_model/signature_constants.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow_serving/apis/model.pb.h"
#include "tensorflow_serving/apis/predict.pb.h"
#include "tensorflow_serving/core/servable_handle.h"
#include "tensorflow_serving/model_servers/get_model_status_impl.h"
#include "tensorflow_serving/model_servers/server_core.h"
#include "tensorflow_serving/servables/tensorflow/classification_service.h"
#include "tensorflow_serving/servables/tensorflow/predict_impl.h"
@@ -36,6 +39,8 @@ limitations under the License.
namespace tensorflow {
namespace serving {

using protobuf::util::JsonPrintOptions;
using protobuf::util::MessageToJsonString;
using tensorflow::serving::ServerCore;
using tensorflow::serving::TensorflowPredictor;

@@ -47,8 +52,9 @@ HttpRestApiHandler::HttpRestApiHandler(const RunOptions& run_options,
core_(core),
predictor_(new TensorflowPredictor(true /* use_saved_model */)),
prediction_api_regex_(
R"((?i)/v1/models/([^/:]+)(?:/versions/(\d+))?:(classify|regress|predict))") {
}
R"((?i)/v1/models/([^/:]+)(?:/versions/(\d+))?:(classify|regress|predict))"),
modelstatus_api_regex_(
R"((?i)/v1/models(?:/([^/:]+))?(?:/versions/(\d+))?)") {}
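For illustration, the model-status route pattern above behaves the same under Python's `re` engine, which makes it easy to sanity-check which paths it accepts (this snippet is a sketch, not part of the server):

```python
import re

# Rough Python equivalent of modelstatus_api_regex_ above; RE2 and
# Python's re agree on this particular pattern.
MODEL_STATUS_RE = re.compile(r"(?i)/v1/models(?:/([^/:]+))?(?:/versions/(\d+))?")

print(MODEL_STATUS_RE.fullmatch("/v1/models/foo/versions/12").groups())
# ('foo', '12')
print(MODEL_STATUS_RE.fullmatch("/v1/models").groups())  # (None, None)
print(MODEL_STATUS_RE.fullmatch("/v1/models/foo:predict"))  # None
```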

HttpRestApiHandler::~HttpRestApiHandler() {}

@@ -102,7 +108,12 @@ Status HttpRestApiHandler::ProcessRequest(
status = ProcessPredictRequest(model_name, model_version, request_body,
output);
}
} else if (http_method == "GET" &&
RE2::FullMatch(string(request_path), modelstatus_api_regex_,
&model_name, &model_version_str)) {
status = ProcessModelStatusRequest(model_name, model_version_str, output);
}

if (!status.ok()) {
FillJsonErrorMsg(status.error_message(), output);
}
@@ -174,6 +185,40 @@ Status HttpRestApiHandler::ProcessPredictRequest(
return Status::OK();
}

Status HttpRestApiHandler::ProcessModelStatusRequest(
const absl::string_view model_name,
const absl::string_view model_version_str, string* output) {
GetModelStatusRequest request;
  // We do not yet support returning the status of all models,
  // to stay in sync with the gRPC GetModelStatus API.
if (model_name.empty()) {
return errors::InvalidArgument("Missing model name in request.");
}
request.mutable_model_spec()->set_name(string(model_name));
if (!model_version_str.empty()) {
int64 version;
if (!absl::SimpleAtoi(model_version_str, &version)) {
return errors::InvalidArgument(
"Failed to convert version: ", model_version_str, " to numeric.");
}
request.mutable_model_spec()->mutable_version()->set_value(version);
}

GetModelStatusResponse response;
TF_RETURN_IF_ERROR(
GetModelStatusImpl::GetModelStatus(core_, request, &response));
JsonPrintOptions opts;
opts.add_whitespace = true;
opts.always_print_primitive_fields = true;
// Note this is protobuf::util::Status (not TF Status) object.
const auto& status = MessageToJsonString(response, output, opts);
if (!status.ok()) {
return errors::Internal("Failed to convert proto to json. Error: ",
status.ToString());
}
return Status::OK();
}

Status HttpRestApiHandler::GetInfoMap(
const ModelSpec& model_spec, const string& signature_name,
::google::protobuf::Map<string, tensorflow::TensorInfo>* infomap) {
24 changes: 13 additions & 11 deletions tensorflow_serving/model_servers/http_rest_api_handler.h
@@ -41,15 +41,18 @@ class ModelSpec;
//
// Currently supported APIs are as follows:
//
// o Predict
// o Inference - Classify/Regress/Predict
//
// Paths:
// /v1/models/<model_name>:predict (uses 'latest' version of the model).
// /v1/models/<model_name>/versions/<version_number>:predict
// POST /v1/models/<model_name>:(classify|regress|predict)
// POST /v1/models/<model_name>/versions/<ver>:(classify|regress|predict)
//
// Request/Response format:
// https://cloud.google.com/ml-engine/docs/v1/predict-request
// o Model status
//
// GET /v1/models/<model_name> (status of all versions)
// GET /v1/models/<model_name>/versions/<ver> (status of specific version)
//
// The API is documented here:
// tensorflow_serving/g3doc/api_rest.md
//
// Users of this class should typically create one instance of it at process
// startup, register paths defined by kPathRegex with the in-process HTTP
@@ -73,10 +76,6 @@ class HttpRestApiHandler {

// Process a HTTP request.
//
// If `http_method` (e.g. POST) and `request_path` (e.g. /v1/models/m:predict)
// match one of the supported APIs, the body (JSON object) is processed and
// response (JSON object) is returned in `output` along with output `headers`.
//
// In case of errors, the `headers` and `output` are still relevant as they
// contain detailed error messages, that can be relayed back to the client.
Status ProcessRequest(const absl::string_view http_method,
@@ -98,14 +97,17 @@
const absl::optional<int64>& model_version,
const absl::string_view request_body,
string* output);

Status ProcessModelStatusRequest(const absl::string_view model_name,
const absl::string_view model_version_str,
string* output);
Status GetInfoMap(const ModelSpec& model_spec, const string& signature_name,
::google::protobuf::Map<string, tensorflow::TensorInfo>* infomap);

const RunOptions run_options_;
ServerCore* core_;
std::unique_ptr<TensorflowPredictor> predictor_;
const RE2 prediction_api_regex_;
const RE2 modelstatus_api_regex_;
};

} // namespace serving
58 changes: 47 additions & 11 deletions tensorflow_serving/model_servers/http_rest_api_handler_test.cc
@@ -179,22 +179,12 @@ TEST_F(HttpRestApiHandlerTest, UnsupportedApiCalls) {

status = handler_.ProcessRequest("GET", "/v1/models", "", &headers, &output);
EXPECT_TRUE(errors::IsInvalidArgument(status));
EXPECT_THAT(status.error_message(), HasSubstr("Malformed request"));
EXPECT_THAT(status.error_message(), HasSubstr("Missing model name"));

status = handler_.ProcessRequest("POST", "/v1/models", "", &headers, &output);
EXPECT_TRUE(errors::IsInvalidArgument(status));
EXPECT_THAT(status.error_message(), HasSubstr("Malformed request"));

status =
handler_.ProcessRequest("GET", "/v1/models/foo", "", &headers, &output);
EXPECT_TRUE(errors::IsInvalidArgument(status));
EXPECT_THAT(status.error_message(), HasSubstr("Malformed request"));

status = handler_.ProcessRequest("GET", "/v1/models/foo/version/50", "",
&headers, &output);
EXPECT_TRUE(errors::IsInvalidArgument(status));
EXPECT_THAT(status.error_message(), HasSubstr("Malformed request"));

status = handler_.ProcessRequest("GET", "/v1/models/foo:predict", "",
&headers, &output);
EXPECT_TRUE(errors::IsInvalidArgument(status));
@@ -370,6 +360,52 @@ TEST_F(HttpRestApiHandlerTest, Classify) {
(HeaderList){{"Content-Type", "application/json"}}));
}

TEST_F(HttpRestApiHandlerTest, GetStatus) {
HeaderList headers;
string output;
Status status;

// Get status for all versions.
TF_EXPECT_OK(handler_.ProcessRequest(
"GET", absl::StrCat("/v1/models/", kTestModelName), "", &headers,
&output));
EXPECT_THAT(headers, UnorderedElementsAreArray(
(HeaderList){{"Content-Type", "application/json"}}));
TF_EXPECT_OK(CompareJson(output, R"({
"model_version_status": [
{
"version": "123",
"state": "AVAILABLE",
"status": {
"error_code": "OK",
"error_message": ""
}
}
]
})"));

// Get status of specific version.
TF_EXPECT_OK(
handler_.ProcessRequest("GET",
absl::StrCat("/v1/models/", kTestModelName,
"/versions/", kTestModelVersion1),
"", &headers, &output));
EXPECT_THAT(headers, UnorderedElementsAreArray(
(HeaderList){{"Content-Type", "application/json"}}));
TF_EXPECT_OK(CompareJson(output, R"({
"model_version_status": [
{
"version": "123",
"state": "AVAILABLE",
"status": {
"error_code": "OK",
"error_message": ""
}
}
]
})"));
}

} // namespace
} // namespace serving
} // namespace tensorflow
