Skip to content

Commit

Permalink
Automated Code Change
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 609201759
  • Loading branch information
ckennelly authored and tensorflow-copybara committed Feb 22, 2024
1 parent 9d525a2 commit f761fc7
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 28 deletions.
24 changes: 10 additions & 14 deletions tensorflow_serving/model_servers/http_rest_api_handler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -115,14 +115,13 @@ Status HttpRestApiHandler::ProcessClassifyRequest(
const absl::string_view request_body, string* output) {
::google::protobuf::Arena arena;

auto* request = ::google::protobuf::Arena::CreateMessage<ClassificationRequest>(&arena);
auto* request = ::google::protobuf::Arena::Create<ClassificationRequest>(&arena);
TF_RETURN_IF_ERROR(FillModelSpecWithNameVersionAndLabel(
model_name, model_version, model_version_label,
request->mutable_model_spec()));
TF_RETURN_IF_ERROR(FillClassificationRequestFromJson(request_body, request));

auto* response =
::google::protobuf::Arena::CreateMessage<ClassificationResponse>(&arena);
auto* response = ::google::protobuf::Arena::Create<ClassificationResponse>(&arena);
TF_RETURN_IF_ERROR(TensorflowClassificationServiceImpl::Classify(
run_options_, core_, thread::ThreadPoolOptions(), *request, response));
TF_RETURN_IF_ERROR(
Expand All @@ -137,13 +136,13 @@ Status HttpRestApiHandler::ProcessRegressRequest(
const absl::string_view request_body, string* output) {
::google::protobuf::Arena arena;

auto* request = ::google::protobuf::Arena::CreateMessage<RegressionRequest>(&arena);
auto* request = ::google::protobuf::Arena::Create<RegressionRequest>(&arena);
TF_RETURN_IF_ERROR(FillModelSpecWithNameVersionAndLabel(
model_name, model_version, model_version_label,
request->mutable_model_spec()));
TF_RETURN_IF_ERROR(FillRegressionRequestFromJson(request_body, request));

auto* response = ::google::protobuf::Arena::CreateMessage<RegressionResponse>(&arena);
auto* response = ::google::protobuf::Arena::Create<RegressionResponse>(&arena);
TF_RETURN_IF_ERROR(TensorflowRegressionServiceImpl::Regress(
run_options_, core_, thread::ThreadPoolOptions(), *request, response));
TF_RETURN_IF_ERROR(MakeJsonFromRegressionResult(response->result(), output));
Expand All @@ -157,7 +156,7 @@ Status HttpRestApiHandler::ProcessPredictRequest(
const absl::string_view request_body, string* output) {
::google::protobuf::Arena arena;

auto* request = ::google::protobuf::Arena::CreateMessage<PredictRequest>(&arena);
auto* request = ::google::protobuf::Arena::Create<PredictRequest>(&arena);
TF_RETURN_IF_ERROR(FillModelSpecWithNameVersionAndLabel(
model_name, model_version, model_version_label,
request->mutable_model_spec()));
Expand All @@ -171,7 +170,7 @@ Status HttpRestApiHandler::ProcessPredictRequest(
},
request, &format));

auto* response = ::google::protobuf::Arena::CreateMessage<PredictResponse>(&arena);
auto* response = ::google::protobuf::Arena::Create<PredictResponse>(&arena);
TF_RETURN_IF_ERROR(
predictor_->Predict(run_options_, core_, *request, response));
TF_RETURN_IF_ERROR(MakeJsonFromTensors(response->outputs(), format, output));
Expand All @@ -191,13 +190,12 @@ Status HttpRestApiHandler::ProcessModelStatusRequest(

::google::protobuf::Arena arena;

auto* request = ::google::protobuf::Arena::CreateMessage<GetModelStatusRequest>(&arena);
auto* request = ::google::protobuf::Arena::Create<GetModelStatusRequest>(&arena);
TF_RETURN_IF_ERROR(FillModelSpecWithNameVersionAndLabel(
model_name, model_version, model_version_label,
request->mutable_model_spec()));

auto* response =
::google::protobuf::Arena::CreateMessage<GetModelStatusResponse>(&arena);
auto* response = ::google::protobuf::Arena::Create<GetModelStatusResponse>(&arena);
TF_RETURN_IF_ERROR(
GetModelStatusImpl::GetModelStatus(core_, *request, response));
return ToJsonString(*response, output);
Expand All @@ -214,16 +212,14 @@ Status HttpRestApiHandler::ProcessModelMetadataRequest(

::google::protobuf::Arena arena;

auto* request =
::google::protobuf::Arena::CreateMessage<GetModelMetadataRequest>(&arena);
auto* request = ::google::protobuf::Arena::Create<GetModelMetadataRequest>(&arena);
// We currently only support the kSignatureDef metadata field
request->add_metadata_field(GetModelMetadataImpl::kSignatureDef);
TF_RETURN_IF_ERROR(FillModelSpecWithNameVersionAndLabel(
model_name, model_version, model_version_label,
request->mutable_model_spec()));

auto* response =
::google::protobuf::Arena::CreateMessage<GetModelMetadataResponse>(&arena);
auto* response = ::google::protobuf::Arena::Create<GetModelMetadataResponse>(&arena);
TF_RETURN_IF_ERROR(
GetModelMetadataImpl::GetModelMetadata(core_, *request, response));
return ToJsonString(*response, output);
Expand Down
24 changes: 10 additions & 14 deletions tensorflow_serving/model_servers/tfrt_http_rest_api_handler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -122,14 +122,13 @@ Status TFRTHttpRestApiHandler::ProcessClassifyRequest(
const Servable::RunOptions& run_options, std::string* output) {
::google::protobuf::Arena arena;

auto* request = ::google::protobuf::Arena::CreateMessage<ClassificationRequest>(&arena);
auto* request = ::google::protobuf::Arena::Create<ClassificationRequest>(&arena);
TF_RETURN_IF_ERROR(FillModelSpecWithNameVersionAndLabel(
model_name, model_version, model_version_label,
request->mutable_model_spec()));
TF_RETURN_IF_ERROR(FillClassificationRequestFromJson(request_body, request));

auto* response =
::google::protobuf::Arena::CreateMessage<ClassificationResponse>(&arena);
auto* response = ::google::protobuf::Arena::Create<ClassificationResponse>(&arena);
ServableHandle<Servable> servable;
TF_RETURN_IF_ERROR(
core_->GetServableHandle(request->model_spec(), &servable));
Expand All @@ -147,13 +146,13 @@ Status TFRTHttpRestApiHandler::ProcessRegressRequest(
const Servable::RunOptions& run_options, std::string* output) {
::google::protobuf::Arena arena;

auto* request = ::google::protobuf::Arena::CreateMessage<RegressionRequest>(&arena);
auto* request = ::google::protobuf::Arena::Create<RegressionRequest>(&arena);
TF_RETURN_IF_ERROR(FillModelSpecWithNameVersionAndLabel(
model_name, model_version, model_version_label,
request->mutable_model_spec()));
TF_RETURN_IF_ERROR(FillRegressionRequestFromJson(request_body, request));

auto* response = ::google::protobuf::Arena::CreateMessage<RegressionResponse>(&arena);
auto* response = ::google::protobuf::Arena::Create<RegressionResponse>(&arena);
ServableHandle<Servable> servable;
TF_RETURN_IF_ERROR(
core_->GetServableHandle(request->model_spec(), &servable));
Expand All @@ -169,7 +168,7 @@ Status TFRTHttpRestApiHandler::ProcessPredictRequest(
const Servable::RunOptions& run_options, std::string* output) {
::google::protobuf::Arena arena;

auto* request = ::google::protobuf::Arena::CreateMessage<PredictRequest>(&arena);
auto* request = ::google::protobuf::Arena::Create<PredictRequest>(&arena);
TF_RETURN_IF_ERROR(FillModelSpecWithNameVersionAndLabel(
model_name, model_version, model_version_label,
request->mutable_model_spec()));
Expand All @@ -183,7 +182,7 @@ Status TFRTHttpRestApiHandler::ProcessPredictRequest(
},
request, &format));

auto* response = ::google::protobuf::Arena::CreateMessage<PredictResponse>(&arena);
auto* response = ::google::protobuf::Arena::Create<PredictResponse>(&arena);

ServableHandle<Servable> servable;
TF_RETURN_IF_ERROR(
Expand All @@ -207,13 +206,12 @@ Status TFRTHttpRestApiHandler::ProcessModelStatusRequest(

::google::protobuf::Arena arena;

auto* request = ::google::protobuf::Arena::CreateMessage<GetModelStatusRequest>(&arena);
auto* request = ::google::protobuf::Arena::Create<GetModelStatusRequest>(&arena);
TF_RETURN_IF_ERROR(FillModelSpecWithNameVersionAndLabel(
model_name, model_version, model_version_label,
request->mutable_model_spec()));

auto* response =
::google::protobuf::Arena::CreateMessage<GetModelStatusResponse>(&arena);
auto* response = ::google::protobuf::Arena::Create<GetModelStatusResponse>(&arena);
TF_RETURN_IF_ERROR(
GetModelStatusImpl::GetModelStatus(core_, *request, response));
return ToJsonString(*response, output);
Expand All @@ -230,16 +228,14 @@ Status TFRTHttpRestApiHandler::ProcessModelMetadataRequest(

::google::protobuf::Arena arena;

auto* request =
::google::protobuf::Arena::CreateMessage<GetModelMetadataRequest>(&arena);
auto* request = ::google::protobuf::Arena::Create<GetModelMetadataRequest>(&arena);
// We currently only support the kSignatureDef metadata field
request->add_metadata_field(std::string(kSignatureDef));
TF_RETURN_IF_ERROR(FillModelSpecWithNameVersionAndLabel(
model_name, model_version, model_version_label,
request->mutable_model_spec()));

auto* response =
::google::protobuf::Arena::CreateMessage<GetModelMetadataResponse>(&arena);
auto* response = ::google::protobuf::Arena::Create<GetModelMetadataResponse>(&arena);
TF_RETURN_IF_ERROR(
TFRTGetModelMetadataImpl::GetModelMetadata(core_, *request, response));

Expand Down

0 comments on commit f761fc7

Please sign in to comment.