Skip to content

Commit

Permalink
Add support to use a MockServable in MockServerCore.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 612937820
  • Loading branch information
kenfranko authored and tensorflow-copybara committed Mar 5, 2024
1 parent 5b5d30f commit 5b6e0b6
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 0 deletions.
1 change: 1 addition & 0 deletions tensorflow_serving/model_servers/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ cc_library(
"//tensorflow_serving/resources:resource_values",
"//tensorflow_serving/servables/tensorflow:predict_util",
"//tensorflow_serving/servables/tensorflow:saved_model_bundle_source_adapter",
"//tensorflow_serving/servables/tensorflow:servable",
"//tensorflow_serving/sources/storage_path:file_system_storage_path_source",
"//tensorflow_serving/util:event_bus",
"//tensorflow_serving/util:unique_ptr_with_deps",
Expand Down
8 changes: 8 additions & 0 deletions tensorflow_serving/model_servers/server_core.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ limitations under the License.
#include "tensorflow_serving/core/storage_path.h"
#include "tensorflow_serving/core/stream_logger.h"
#include "tensorflow_serving/servables/tensorflow/predict_util.h"
#include "tensorflow_serving/servables/tensorflow/servable.h"
#include "tensorflow_serving/sources/storage_path/file_system_storage_path_source.h"
#include "tensorflow_serving/util/event_bus.h"
#include "tensorflow_serving/util/unique_ptr_with_deps.h"
Expand Down Expand Up @@ -268,6 +269,13 @@ class ServerCore : public Manager {
return Status();
}

// This specialized version allows us to override GetServableHandle for
// Servables in sub-classes. Useful for testing.
virtual Status GetServableHandle(const ModelSpec& model_spec,
ServableHandle<Servable>* const handle) {
return GetServableHandle<Servable>(model_spec, handle);
}

/// Writes the log for the particular request, response and metadata, if we
/// decide to sample it and if request-logging was configured for the
/// particular model.
Expand Down
5 changes: 5 additions & 0 deletions tensorflow_serving/model_servers/test_util/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,19 @@ cc_library(
"//tensorflow_serving/config:platform_config_cc_proto",
"//tensorflow_serving/core:aspired_versions_manager",
"//tensorflow_serving/core:servable_handle",
"//tensorflow_serving/core:servable_id",
"//tensorflow_serving/core:servable_state",
"//tensorflow_serving/core:servable_state_monitor",
"//tensorflow_serving/core:server_request_logger",
"//tensorflow_serving/core/test_util:fake_loader_source_adapter",
"//tensorflow_serving/core/test_util:fake_loader_source_adapter_cc_proto",
"//tensorflow_serving/core/test_util:servable_handle_test_util",
"//tensorflow_serving/model_servers:server_core",
"//tensorflow_serving/servables/tensorflow:mock_servable",
"//tensorflow_serving/servables/tensorflow:servable",
"//tensorflow_serving/util:event_bus",
"//tensorflow_serving/util:unique_ptr_with_deps",
"@com_google_absl//absl/status",
"@com_google_googletest//:gtest",
"@com_google_protobuf//:cc_wkt_protos",
"@com_google_protobuf//:protobuf_lite",
Expand Down
26 changes: 26 additions & 0 deletions tensorflow_serving/model_servers/test_util/mock_server_core.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ limitations under the License.
#include "google/protobuf/any.pb.h"
#include "google/protobuf/map.h"
#include <gmock/gmock.h>
#include "absl/status/status.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow_serving/apis/logging.proto.h"
Expand All @@ -32,11 +33,14 @@ limitations under the License.
#include "tensorflow_serving/config/platform_config.pb.h"
#include "tensorflow_serving/core/aspired_versions_manager.h"
#include "tensorflow_serving/core/servable_handle.h"
#include "tensorflow_serving/core/servable_id.h"
#include "tensorflow_serving/core/servable_state.h"
#include "tensorflow_serving/core/servable_state_monitor.h"
#include "tensorflow_serving/core/server_request_logger.h"
#include "tensorflow_serving/core/test_util/fake_loader_source_adapter.pb.h"
#include "tensorflow_serving/core/test_util/servable_handle_test_util.h"
#include "tensorflow_serving/model_servers/server_core.h"
#include "tensorflow_serving/servables/tensorflow/servable.h"
#include "tensorflow_serving/util/event_bus.h"
#include "tensorflow_serving/util/unique_ptr_with_deps.h"

Expand Down Expand Up @@ -84,6 +88,7 @@ class MockServerCore : public ServerCore {

explicit MockServerCore(const PlatformConfigMap& platform_config_map)
: MockServerCore(platform_config_map, nullptr) {}

MockServerCore(const PlatformConfigMap& platform_config_map,
std::unique_ptr<ServerRequestLogger> server_request_logger)
: ServerCore(GetOptions(platform_config_map,
Expand All @@ -97,11 +102,32 @@ class MockServerCore : public ServerCore {
const LogMetadata& log_metadata),
(override));

// Sets the Servable used by GetServableHandle
void SetServable(std::unique_ptr<Servable> servable) {
servable_ = std::move(servable);
}

template <typename T>
Status GetServableHandle(const ModelSpec& model_spec,
ServableHandle<T>* const handle) {
LOG(FATAL) << "Not implemented.";
}

// Implement GetServable for type Servable. Will return the Servable
// set by SetServable, otherwise forwards to base class.
virtual Status GetServableHandle(
const ModelSpec& model_spec,
ServableHandle<Servable>* const handle) override {
if (servable_) {
const ServableId id = {"servable", 0};
*handle = WrapAsHandle<Servable>(id, servable_.get());
return absl::OkStatus();
} else {
return ServerCore::GetServableHandle<Servable>(model_spec, handle);
}
}

std::unique_ptr<Servable> servable_;
};

} // namespace test_util
Expand Down

0 comments on commit 5b6e0b6

Please sign in to comment.