Commit

Replace the global registration with a registration class so that when we move server_init_internal to OSS we won't run into an undetermined global registration sequence issue.

PiperOrigin-RevId: 592632789
tensorflower-gardener authored and tensorflow-copybara committed Dec 20, 2023
1 parent a635552 commit 21d8f88
Showing 7 changed files with 149 additions and 67 deletions.
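
For context: C++ leaves the relative order of dynamic initialization of namespace-scope globals across translation units unspecified, so mutable global function pointers that one library overwrites and another reads can be consumed before the overwrite runs. A function-local static, by contrast, is initialized on first use. Below is a minimal, self-contained sketch of the pattern this commit adopts; every name in it is hypothetical, illustrating the idea rather than reproducing the commit's types.

#include <iostream>
#include <string>

using GreetFn = std::string (*)();

std::string DefaultGreet() { return "default platform"; }
std::string AltGreet() { return "alternate platform"; }

class FunctionRegistry {
 public:
  // Meyers-style singleton: the local static is constructed the first
  // time GetRegistry() runs, so lookups are always well ordered relative
  // to registrations, no matter which translation unit initializes first.
  static FunctionRegistry* GetRegistry() {
    static auto* registry = new FunctionRegistry();
    return registry;
  }

  // A later registration (e.g. an alternate platform) overwrites the default.
  void Register(const std::string& type, GreetFn fn) {
    type_ = type;
    greet_ = fn;
  }

  GreetFn GetGreet() const { return greet_; }

 private:
  // The constructor installs the default registration, so the registry is
  // never observed in an unregistered state.
  FunctionRegistry() { Register("default", DefaultGreet); }

  std::string type_;
  GreetFn greet_ = nullptr;
};

int main() {
  FunctionRegistry::GetRegistry()->Register("alt", AltGreet);
  std::cout << FunctionRegistry::GetRegistry()->GetGreet()() << std::endl;
  return 0;
}
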
7 changes: 7 additions & 0 deletions tensorflow_serving/model_servers/BUILD
@@ -198,6 +198,7 @@ cc_test(
"@com_google_absl//absl/strings",
"@com_google_absl//absl/time",
"@com_googlesource_code_re2//:re2",
"@local_tsl//tsl/platform:errors",
"@org_tensorflow//tensorflow/cc/saved_model:loader",
"@org_tensorflow//tensorflow/cc/saved_model:signature_constants",
"@org_tensorflow//tensorflow/core:lib",
@@ -274,6 +275,7 @@ cc_library(
"//tensorflow_serving/servables/tensorflow:util",
"//tensorflow_serving/util:prometheus_exporter",
"//tensorflow_serving/util:threadpool_executor",
"//tensorflow_serving/util/net_http/public:shared_files",
"//tensorflow_serving/util/net_http/server/public:http_server",
"//tensorflow_serving/util/net_http/server/public:http_server_api",
"@com_google_absl//absl/strings",
@@ -437,9 +439,12 @@ cc_library(
":prediction_service_impl",
":server_core",
"//tensorflow_serving/apis:prediction_service_cc_proto",
"//tensorflow_serving/model_servers:prediction_service_util",
"//tensorflow_serving/servables/tensorflow:saved_model_bundle_source_adapter",
"//tensorflow_serving/servables/tensorflow:session_bundle_config_cc_proto",
"@com_github_grpc_grpc//:grpc++",
"@com_google_absl//absl/log",
"@com_google_absl//absl/strings:string_view",
"@com_google_protobuf//:cc_wkt_protos",
"@org_tensorflow//tensorflow/core:lib",
],
@@ -479,11 +484,13 @@ cc_library(
"@com_github_grpc_grpc//:grpc++",
"@com_google_absl//absl/memory",
"@com_google_protobuf//:cc_wkt_protos",
"@local_tsl//tsl/platform:errors",
"@org_tensorflow//tensorflow/c:c_api",
"@org_tensorflow//tensorflow/cc/saved_model:tag_constants",
"@org_tensorflow//tensorflow/core:lib",
"@org_tensorflow//tensorflow/core:protos_all_cc",
"@org_tensorflow//tensorflow/core:tensorflow",
"@org_tensorflow//tensorflow/core/kernels/batching_util:periodic_function_dynamic",
"@org_tensorflow//tensorflow/core/profiler/rpc:profiler_service_impl",
] + SUPPORTED_TENSORFLOW_OPS,
)
tensorflow_serving/model_servers/http_rest_api_handler_test.cc
@@ -34,6 +34,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/env.h"
#include "tsl/platform/errors.h"
#include "tensorflow_serving/core/availability_preserving_policy.h"
#include "tensorflow_serving/model_servers/model_platform_types.h"
#include "tensorflow_serving/model_servers/platform_config_util.h"
@@ -96,9 +97,10 @@ class HttpRestApiHandlerTest : public ::testing::Test {
ServerCore::Options options;
options.model_server_config = config;

TF_RETURN_IF_ERROR(
tensorflow::serving::init::SetupPlatformConfigMapForTensorFlow(
SessionBundleConfig(), options.platform_config_map));
auto* tf_serving_registry =
init::TensorflowServingFunctionRegistration::GetRegistry();
TF_RETURN_IF_ERROR(tf_serving_registry->GetSetupPlatformConfigMap()(
SessionBundleConfig(), options.platform_config_map));
// Reduce the number of initial load threads to be num_load_threads to avoid
// timing out in tests.
options.num_initial_load_threads = options.num_load_threads;
10 changes: 6 additions & 4 deletions tensorflow_serving/model_servers/http_server.cc
@@ -130,10 +130,12 @@ class RequestExecutor final : public net_http::EventExecutor {
class RestApiRequestDispatcher {
public:
RestApiRequestDispatcher(int timeout_in_ms, ServerCore* core)
: regex_(HttpRestApiHandler::kPathRegex),
core_(core),
handler_(tensorflow::serving::init::CreateHttpRestApiHandler(
timeout_in_ms, core)) {}
: regex_(HttpRestApiHandler::kPathRegex), core_(core) {
auto* tf_serving_registry = tensorflow::serving::init::
TensorflowServingFunctionRegistration::GetRegistry();
handler_ =
tf_serving_registry->GetCreateHttpRestApiHandler()(timeout_in_ms, core);
}

net_http::RequestHandler Dispatch(net_http::ServerRequestInterface* req) {
if (RE2::FullMatch(string(req->uri_path()), regex_)) {
18 changes: 10 additions & 8 deletions tensorflow_serving/model_servers/server.cc
@@ -41,6 +41,7 @@ limitations under the License.
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/profiler/rpc/profiler_service_impl.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tsl/platform/errors.h"
#include "tensorflow_serving/config/model_server_config.pb.h"
#include "tensorflow_serving/config/monitoring_config.pb.h"
#include "tensorflow_serving/config/platform_config.pb.h"
@@ -212,6 +213,9 @@ Status Server::BuildAndStart(const Options& server_options) {
server_options.model_config_file, &options.model_server_config));
}

auto* tf_serving_registry =
init::TensorflowServingFunctionRegistration::GetRegistry();

if (server_options.platform_config_file.empty()) {
SessionBundleConfig session_bundle_config;
// Batching config
@@ -290,15 +294,13 @@ Status Server::BuildAndStart(const Options& server_options) {
server_options.num_tflite_interpreters_per_pool);
session_bundle_config.set_num_tflite_pools(server_options.num_tflite_pools);

TF_RETURN_IF_ERROR(
tensorflow::serving::init::SetupPlatformConfigMapForTensorFlow(
session_bundle_config, options.platform_config_map));
TF_RETURN_IF_ERROR(tf_serving_registry->GetSetupPlatformConfigMap()(
session_bundle_config, options.platform_config_map));
} else {
TF_RETURN_IF_ERROR(ParseProtoTextFile<PlatformConfigMap>(
server_options.platform_config_file, &options.platform_config_map));
TF_RETURN_IF_ERROR(
tensorflow::serving::init::UpdatePlatformConfigMapForTensorFlow(
options.platform_config_map));
TF_RETURN_IF_ERROR(tf_serving_registry->GetUpdatePlatformConfigMap()(
options.platform_config_map));
}

options.custom_model_config_loader = &LoadCustomModelConfig;
@@ -357,8 +359,8 @@ Status Server::BuildAndStart(const Options& server_options) {
&thread_pool_factory_));
}
predict_server_options.thread_pool_factory = thread_pool_factory_.get();
prediction_service_ = tensorflow::serving::init::CreatePredictionService(
predict_server_options);
prediction_service_ =
tf_serving_registry->GetCreatePredictionService()(predict_server_options);

::grpc::ServerBuilder builder;
// If defined, listen to a tcp port for gRPC/HTTP.
25 changes: 15 additions & 10 deletions tensorflow_serving/model_servers/server_init.cc
@@ -17,6 +17,8 @@ limitations under the License.

#include <memory>

#include "absl/log/log.h"
#include "absl/strings/string_view.h"
#include "tensorflow_serving/model_servers/http_rest_api_handler.h"
#include "tensorflow_serving/model_servers/platform_config_util.h"
#include "tensorflow_serving/model_servers/prediction_service_impl.h"
@@ -49,16 +51,19 @@ std::unique_ptr<PredictionService::Service> CreatePredictionServiceImpl(
return absl::make_unique<PredictionServiceImpl>(options);
}

ABSL_CONST_INIT SetupPlatformConfigMapForTensorFlowFnType
SetupPlatformConfigMapForTensorFlow =
SetupPlatformConfigMapForTensorFlowImpl;
ABSL_CONST_INIT UpdatePlatformConfigMapForTensorFlowFnType
UpdatePlatformConfigMapForTensorFlow =
UpdatePlatformConfigMapForTensorFlowImpl;
ABSL_CONST_INIT CreateHttpRestApiHandlerFnType CreateHttpRestApiHandler =
CreateHttpRestApiHandlerImpl;
ABSL_CONST_INIT CreatePredictionServiceFnType CreatePredictionService =
CreatePredictionServiceImpl;
void TensorflowServingFunctionRegistration::Register(
absl::string_view type,
SetupPlatformConfigMapForTensorFlowFnType setup_platform_config_map_func,
UpdatePlatformConfigMapForTensorFlowFnType update_platform_config_map_func,
CreateHttpRestApiHandlerFnType create_http_rest_api_handler_func,
CreatePredictionServiceFnType create_prediction_service_func) {
VLOG(1) << "Registering serving functions for " << type;
registration_type_ = type;
setup_platform_config_map_ = setup_platform_config_map_func;
update_platform_config_map_ = update_platform_config_map_func;
create_http_rest_api_handler_ = create_http_rest_api_handler_func;
create_prediction_service_ = create_prediction_service_func;
}

} // namespace init
} // namespace serving
86 changes: 69 additions & 17 deletions tensorflow_serving/model_servers/server_init.h
@@ -16,6 +16,8 @@ limitations under the License.
#ifndef THIRD_PARTY_TENSORFLOW_SERVING_MODEL_SERVERS_SERVER_INIT_H_
#define THIRD_PARTY_TENSORFLOW_SERVING_MODEL_SERVERS_SERVER_INIT_H_

#include <string>

#include "google/protobuf/any.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow_serving/apis/prediction_service.grpc.pb.h"
@@ -38,28 +40,78 @@ using CreatePredictionServiceFnType =
std::unique_ptr<PredictionService::Service> (*)(
const PredictionServiceOptions&);

Status SetupPlatformConfigMapImpl(const SessionBundleConfig&,
PlatformConfigMap&);
Status SetupPlatformConfigMapFromConfigFileImpl(const string&,
PlatformConfigMap&);
Status SetupPlatformConfigMapForTensorFlowImpl(const SessionBundleConfig&,
PlatformConfigMap&);
Status UpdatePlatformConfigMapForTensorFlowImpl(PlatformConfigMap&);
std::unique_ptr<HttpRestApiHandlerBase> CreateHttpRestApiHandlerImpl(
int, ServerCore*);
std::unique_ptr<PredictionService::Service> CreatePredictionServiceImpl(
const PredictionServiceOptions&);

// Setup the 'TensorFlow' PlatformConfigMap from the specified
// SessionBundleConfig.
extern SetupPlatformConfigMapForTensorFlowFnType
SetupPlatformConfigMapForTensorFlow;
// If the PlatformConfigMap contains the config for the 'TensorFlow' platform,
// update the PlatformConfigMap when necessary.
extern UpdatePlatformConfigMapForTensorFlowFnType
UpdatePlatformConfigMapForTensorFlow;
// Create an HttpRestApiHandler object that handles HTTP/REST request APIs for
// serving.
extern CreateHttpRestApiHandlerFnType CreateHttpRestApiHandler;
// Create a PredictionService object that handles gRPC request APIs for serving.
extern CreatePredictionServiceFnType CreatePredictionService;
// Register the tensorflow serving functions.
class TensorflowServingFunctionRegistration {
public:
virtual ~TensorflowServingFunctionRegistration() = default;

// Get the registry singleton.
static TensorflowServingFunctionRegistration* GetRegistry() {
static auto* registration = new TensorflowServingFunctionRegistration();
return registration;
}

// The tensorflow serving function registration. For TFRT, the TFRT
// registration will overwrite the Tensorflow registration.
void Register(
absl::string_view type,
SetupPlatformConfigMapForTensorFlowFnType setup_platform_config_map_func,
UpdatePlatformConfigMapForTensorFlowFnType
update_platform_config_map_func,
CreateHttpRestApiHandlerFnType create_http_rest_api_handler_func,
CreatePredictionServiceFnType create_prediction_service_func);

bool IsRegistered() const { return !registration_type_.empty(); }

SetupPlatformConfigMapForTensorFlowFnType GetSetupPlatformConfigMap() const {
return setup_platform_config_map_;
}

UpdatePlatformConfigMapForTensorFlowFnType GetUpdatePlatformConfigMap()
const {
return update_platform_config_map_;
}

CreateHttpRestApiHandlerFnType GetCreateHttpRestApiHandler() const {
return create_http_rest_api_handler_;
}

CreatePredictionServiceFnType GetCreatePredictionService() const {
return create_prediction_service_;
}

private:
TensorflowServingFunctionRegistration() {
Register("tensorflow", init::SetupPlatformConfigMapForTensorFlowImpl,
init::UpdatePlatformConfigMapForTensorFlowImpl,
init::CreateHttpRestApiHandlerImpl,
init::CreatePredictionServiceImpl);
}

// The registration type, indicating the platform, e.g. tensorflow, tfrt.
std::string registration_type_ = "";

// Setup the 'TensorFlow' PlatformConfigMap from the specified
// SessionBundleConfig.
SetupPlatformConfigMapForTensorFlowFnType setup_platform_config_map_;
// If the PlatformConfigMap contains the config for the 'TensorFlow'
// platform, update the PlatformConfigMap when necessary.
UpdatePlatformConfigMapForTensorFlowFnType update_platform_config_map_;
// Create an HttpRestApiHandler object that handles HTTP/REST request APIs
// for serving.
CreateHttpRestApiHandlerFnType create_http_rest_api_handler_;
// Create a PredictionService object that handles gRPC request APIs for
// serving.
CreatePredictionServiceFnType create_prediction_service_;
};

} // namespace init
} // namespace serving
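
As the comment on Register() notes, a second platform (e.g. TFRT) overwrites the default "tensorflow" registration. One plausible way to do that is from a static initializer in the platform's own translation unit. A hedged sketch follows: the Stub* implementations and the "stub_platform" name are placeholders invented for the example so it is self-contained; only the Register() signature comes from server_init.h above.

// Hypothetical registration from an alternate platform. A real platform
// would supply working implementations instead of these stubs.
#include <memory>

#include "tensorflow_serving/model_servers/server_init.h"

namespace tensorflow {
namespace serving {
namespace init {
namespace {

Status StubSetupPlatformConfigMap(const SessionBundleConfig& config,
                                  PlatformConfigMap& platform_config_map) {
  return Status();  // placeholder: a real platform would populate the map
}

Status StubUpdatePlatformConfigMap(PlatformConfigMap& platform_config_map) {
  return Status();  // placeholder
}

std::unique_ptr<HttpRestApiHandlerBase> StubCreateHttpRestApiHandler(
    int timeout_in_ms, ServerCore* core) {
  return nullptr;  // placeholder
}

std::unique_ptr<PredictionService::Service> StubCreatePredictionService(
    const PredictionServiceOptions& options) {
  return nullptr;  // placeholder
}

// Runs during this translation unit's static initialization. Calling
// GetRegistry() here is safe even before main(), because the registry
// lives in a function-local static rather than a namespace-scope global.
const bool kRegistered = []() {
  TensorflowServingFunctionRegistration::GetRegistry()->Register(
      "stub_platform", StubSetupPlatformConfigMap,
      StubUpdatePlatformConfigMap, StubCreateHttpRestApiHandler,
      StubCreatePredictionService);
  return true;
}();

}  // namespace
}  // namespace init
}  // namespace serving
}  // namespace tensorflow
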
tensorflow_serving/model_servers/tensorflow_model_server_test_base.py
@@ -129,32 +129,36 @@ def SortedObject(obj):
return obj


def GetArgsKey(*args, **kwargs):
return args + tuple(sorted(kwargs.items()))


class TensorflowModelServerTestBase(tf.test.TestCase):
"""This class defines integration test cases for tensorflow_model_server."""

@staticmethod
def __TestSrcDirPath(relative_path=''):
@classmethod
def TestSrcDirPath(cls, relative_path=''):
return os.path.join(os.environ['TEST_SRCDIR'],
'tf_serving/tensorflow_serving', relative_path)

@staticmethod
def GetArgsKey(*args, **kwargs):
return args + tuple(sorted(kwargs.items()))

# Maps string key -> 2-tuple of 'host:port' string.
model_servers_dict = {}

@staticmethod
def RunServer(model_name,
model_path,
model_type='tf',
model_config_file=None,
monitoring_config_file=None,
batching_parameters_file=None,
grpc_channel_arguments='',
wait_for_server_ready=True,
pipe=None,
model_config_file_poll_period=None):
@classmethod
def RunServer(
cls,
model_name,
model_path,
model_server_path='model_servers',
model_type='tf',
model_config_file=None,
monitoring_config_file=None,
batching_parameters_file=None,
grpc_channel_arguments='',
wait_for_server_ready=True,
pipe=None,
model_config_file_poll_period=None,
):
"""Run tensorflow_model_server using test config.
A unique instance of server is started for each set of arguments.
@@ -164,33 +168,36 @@
Args:
model_name: Name of model.
model_path: Path to model.
model_server_path: The additional model server path dir.
model_type: Type of model TensorFlow ('tf') or TF Lite ('tflite').
model_config_file: Path to model config file.
monitoring_config_file: Path to the monitoring config file.
batching_parameters_file: Path to batching parameters.
grpc_channel_arguments: Custom gRPC args for server.
wait_for_server_ready: Wait for gRPC port to be ready.
pipe: subpipe.PIPE object to read stderr from server.
model_config_file_poll_period: Period for polling the
filesystem to discover new model configs.
model_config_file_poll_period: Period for polling the filesystem to
discover new model configs.
Returns:
3-tuple (<Popen object>, <grpc host:port>, <rest host:port>).
Raises:
ValueError: when both model_path and config_file is empty.
"""
args_key = TensorflowModelServerTestBase.GetArgsKey(**locals())
args_key = GetArgsKey(**locals())
if args_key in TensorflowModelServerTestBase.model_servers_dict:
return TensorflowModelServerTestBase.model_servers_dict[args_key]
port = PickUnusedPort()
rest_api_port = PickUnusedPort()
print(('Starting test server on port: {} for model_name: '
'{}/model_config_file: {}'.format(port, model_name,
model_config_file)))

command = os.path.join(
TensorflowModelServerTestBase.__TestSrcDirPath('model_servers'),
'tensorflow_model_server')
TensorflowModelServerTestBase.TestSrcDirPath(model_server_path),
'tensorflow_model_server',
)
command += ' --port=' + str(port)
command += ' --rest_api_port=' + str(rest_api_port)
command += ' --rest_api_timeout_in_ms=' + str(HTTP_REST_TIMEOUT_MS)
@@ -357,19 +364,24 @@ def _TestPredict(
self,
model_path,
batching_parameters_file=None,
signature_name=signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY):
signature_name=signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY,
model_server_path='model_servers',
):
"""Helper method to test prediction.
Args:
model_path: Path to the model on disk.
batching_parameters_file: Batching parameters file to use (if None
batching is not enabled).
batching is not enabled).
signature_name: Signature name to expect in the PredictResponse.
model_server_path: The model server path dir.
"""
model_server_address = TensorflowModelServerTestBase.RunServer(
'default',
model_path,
batching_parameters_file=batching_parameters_file)[1]
batching_parameters_file=batching_parameters_file,
model_server_path=model_server_path,
)[1]
expected_version = self._GetModelVersion(model_path)
self.VerifyPredictRequest(model_server_address, expected_output=3.0,
expected_version=expected_version,
