Skip to content

Commit

Permalink
Merge pull request #86 from chrisolston/branch_122431817
Browse files Browse the repository at this point in the history
Upstream changes from internal
  • Loading branch information
Christopher Olston committed May 20, 2016
2 parents 9769786 + 235bc42 commit b4e9815
Show file tree
Hide file tree
Showing 15 changed files with 427 additions and 61 deletions.
3 changes: 2 additions & 1 deletion tensorflow_serving/core/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -593,10 +593,11 @@ cc_test(
":servable_state",
":servable_state_monitor",
":simple_loader",
"//tensorflow_serving/core/test_util:manager_test_util",
"//tensorflow_serving/core/test_util:test_main",
"//tensorflow_serving/util:any_ptr",
"//tensorflow_serving/util:event_bus",
"//tensorflow_serving/util:optional",
"//tensorflow_serving/util:threadpool_executor",
],
)

Expand Down
21 changes: 21 additions & 0 deletions tensorflow_serving/core/basic_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,15 @@ class BasicManager : public Manager {
std::vector<ServableStateSnapshot<T>> GetManagedServableStateSnapshots(
const string& servable_name) const;

// Returns the state snapshot of the servable with the given id, if this
// manager is currently managing it.
//
// Returns nullopt when no servable with this id is found among the managed
// servables; that is a normal return value, not an error.
//
// Acquires the manager's mutex for the duration of the lookup.
template <typename T = std::nullptr_t>
optional<ServableStateSnapshot<T>> GetManagedServableStateSnapshot(
const ServableId& id);

// Returns the additional state for the servable. Returns nullptr if there is
// no additional state setup or if there is a type mismatch between what was
// setup and what is being asked for.
Expand Down Expand Up @@ -464,6 +473,18 @@ BasicManager::GetManagedServableStateSnapshots(
return state_snapshots;
}

// Looks up the harness managed under 'id' and returns its loader state
// snapshot, or nullopt when no such harness is in the managed map. The lookup
// and the snapshot are both taken while holding mu_.
template <typename T>
optional<ServableStateSnapshot<T>>
BasicManager::GetManagedServableStateSnapshot(const ServableId& id) {
  mutex_lock l(mu_);

  auto harness_it = FindHarnessInMap(id);
  const bool is_managed = harness_it != managed_map_.end();
  if (is_managed) {
    return harness_it->second->loader_state_snapshot<T>();
  }
  return nullopt;
}

template <typename T>
T* BasicManager::GetAdditionalServableState(const ServableId& id) {
mutex_lock l(mu_);
Expand Down
29 changes: 29 additions & 0 deletions tensorflow_serving/core/basic_manager_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,35 @@ TEST_P(BasicManagerTest,
UnorderedElementsAreArray(expected));
}

TEST_P(BasicManagerTest, GetManagedServableStateSnapshot) {
  // Case 1: the servable-id is in ready state, so a snapshot reporting kReady
  // is expected.
  const ServableId ready_id = {kServableName, 1};
  const ServableStateSnapshot<> want_ready = {
      ready_id, LoaderHarness::State::kReady, {}};
  const optional<ServableStateSnapshot<>> got_ready =
      basic_manager_->GetManagedServableStateSnapshot(ready_id);
  EXPECT_TRUE(got_ready);
  EXPECT_EQ(got_ready, want_ready);

  // Case 2: the servable-id is handed to the manager together with an error
  // status, so a snapshot reporting kError is expected.
  const ServableId error_id = {kServableName, 7};
  const ServableStateSnapshot<> want_error = {
      error_id, LoaderHarness::State::kError, {}};
  basic_manager_->ManageServable(ServableData<std::unique_ptr<Loader>>(
      error_id, errors::Internal("An error.")));
  const optional<ServableStateSnapshot<>> got_error =
      basic_manager_->GetManagedServableStateSnapshot(error_id);
  EXPECT_TRUE(got_error);
  EXPECT_EQ(got_error, want_error);

  // Case 3: the servable-id is not managed by the basic-manager, so no
  // snapshot is expected.
  const ServableId unmanaged_id = {kServableName, 8};
  EXPECT_FALSE(basic_manager_->GetManagedServableStateSnapshot(unmanaged_id));
}

TEST_P(BasicManagerTest, GetManagedServableStateSnapshotsWithAdditionalState) {
basic_manager_->ManageServableWithAdditionalState(
CreateServable({kServableName3, 0}), std::unique_ptr<int>(new int(0)));
Expand Down
58 changes: 37 additions & 21 deletions tensorflow_serving/core/caching_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,48 +100,46 @@ Status CachingManager::GetUntypedServableHandleForId(
// requests enforces that exactly one thread performs the load operation with
// the wrapped basic-manager. All other requests block until the load
// completes and then trivially succeed.
LoadServable(servable_id);
TF_RETURN_IF_ERROR(LoadServable(servable_id));

// Return the handle using the loaded servable data now.
return basic_manager_->GetUntypedServableHandle(
ServableRequest::FromId(servable_id), handle);
}

Status CachingManager::LoadServable(const ServableId& servable_id) {
mutex* servable_id_mu;
std::shared_ptr<mutex> servable_id_mu;
{
mutex_lock l(load_mutex_map_mu_);
auto iter = load_mutex_map_.find(servable_id);
if (iter == load_mutex_map_.end()) {
iter = load_mutex_map_
.emplace(servable_id, std::unique_ptr<mutex>(new mutex))
.first;
iter =
load_mutex_map_.emplace(servable_id, std::make_shared<mutex>()).first;
}
servable_id_mu = iter->second.get();
servable_id_mu = iter->second;
}

{
// Ensure only one thread attempts to load the servable at a time.
mutex_lock l(*servable_id_mu);

// Retrieve the state of the servable from the wrapped basic-manager. The
// servable should already be managed by the basic-manager. If a snapshot
// corresponding to the managed servable-id is not found, it is considered
// an error.
// TODO(b/28617799): Update to use the basic-manager API to get the
// servable state snapshot per servable id.
LoaderHarness::State snapshot_state = LoaderHarness::State::kError;
const std::vector<ServableStateSnapshot<>> snapshots =
basic_manager_->GetManagedServableStateSnapshots(servable_id.name);
for (const auto& snapshot : snapshots) {
if (snapshot.id.version == servable_id.version) {
snapshot_state = snapshot.state;
break;
}
// servable should already be managed by the basic-manager.
const optional<ServableStateSnapshot<>> snapshot =
basic_manager_->GetManagedServableStateSnapshot(servable_id);
// If no snapshot is found, the requested servable is not being managed by
// the wrapped basic-manager yet. This is a broken invariant since we expect
// ManageServable() to have been invoked just before calling this method.
// Return an error accordingly.
if (!snapshot) {
const string error_msg = strings::StrCat(
"Servable requested for load is not being managed by the manager: ",
servable_id.DebugString());
DCHECK(false) << error_msg;
return errors::Internal(error_msg);
}

// Load the servable since it has not been loaded yet based on its state.
if (snapshot_state == LoaderHarness::State::kNew) {
if (snapshot.value().state == LoaderHarness::State::kNew) {
Notification load_done;
Status load_status;
basic_manager_->LoadServable(servable_id, [&](const Status& status) {
Expand All @@ -152,9 +150,27 @@ Status CachingManager::LoadServable(const ServableId& servable_id) {
TF_RETURN_IF_ERROR(load_status);
}
}
servable_id_mu.reset();
MaybeEraseLoadMutexMapEntry(servable_id);
return Status::OK();
}

// Garbage-collects the load-mutex map entry for 'servable_id', if one exists.
//
// The entry is erased only when the map holds the last remaining reference to
// the mutex, i.e. no caller still holds a copy of the shared_ptr. Acquires
// load_mutex_map_mu_, so it must not be called with that mutex already held.
void CachingManager::MaybeEraseLoadMutexMapEntry(
    const ServableId& servable_id) {
  mutex_lock l(load_mutex_map_mu_);
  auto iter = load_mutex_map_.find(servable_id);
  if (iter == load_mutex_map_.end()) {
    return;
  }
  // std::shared_ptr::unique() is deprecated in C++17 and removed in C++20, so
  // compare use_count() directly. A count of one means only the map's own
  // reference is left. NOTE(review): this relies on new copies of the
  // shared_ptr being made only while load_mutex_map_mu_ is held, so the count
  // cannot grow between this check and the erase -- confirm no other code path
  // copies the pointer without the lock.
  if (iter->second.use_count() == 1) {
    load_mutex_map_.erase(iter);
  }
}

// Reports how many servable-ids currently have an entry in load_mutex_map_.
// The size is read while holding load_mutex_map_mu_.
int64 CachingManager::GetLoadMutexMapSize() const {
  mutex_lock l(load_mutex_map_mu_);
  const int64 num_entries = load_mutex_map_.size();
  return num_entries;
}

std::map<ServableId, std::unique_ptr<UntypedServableHandle>>
CachingManager::GetAvailableUntypedServableHandles() const {
return basic_manager_->GetAvailableUntypedServableHandles();
Expand Down
20 changes: 17 additions & 3 deletions tensorflow_serving/core/caching_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ limitations under the License.
namespace tensorflow {
namespace serving {

namespace test_util {
class CachingManagerTestAccess;
} // namespace test_util

// A manager that manages and loads servables on-demand. Upon receiving the
// request for a servable name and optional version, the manager checks if it
// already has the requested servable loaded. If not, it initiates the load
Expand Down Expand Up @@ -97,6 +101,8 @@ class CachingManager : public Manager {
std::vector<ServableId> ListAvailableServableIds() const override;

private:
friend class test_util::CachingManagerTestAccess;

CachingManager(std::unique_ptr<LoaderFactory> loader_factory,
std::unique_ptr<BasicManager> basic_manager);

Expand All @@ -122,6 +128,13 @@ class CachingManager : public Manager {
Status LoadServable(const ServableId& servable_id)
LOCKS_EXCLUDED(load_mutex_map_mu_);

// Returns the size of the load_mutex_map_.
int64 GetLoadMutexMapSize() const LOCKS_EXCLUDED(load_mutex_map_mu_);

// Erases the entry from the map corresponding to the servable-id if there is
// only one remaining reference to the mutex.
void MaybeEraseLoadMutexMapEntry(const ServableId& servable_id);

std::unique_ptr<LoaderFactory> loader_factory_;

std::unique_ptr<BasicManager> basic_manager_;
Expand All @@ -130,9 +143,10 @@ class CachingManager : public Manager {
mutable mutex load_mutex_map_mu_;

// Map of servable-id to a mutex, which is required to synchronize calls to
// load the servable using the wrapped basic-manager.
// TODO(b/28445976): Add support for garbage-collection of map entries.
std::map<ServableId, std::unique_ptr<mutex>> load_mutex_map_
// load the servable using the wrapped basic-manager. The value in the map is
// a shared_ptr to allow for reference counting and consequent garbage
// collection.
std::map<ServableId, std::shared_ptr<mutex>> load_mutex_map_
GUARDED_BY(load_mutex_map_mu_);

TF_DISALLOW_COPY_AND_ASSIGN(CachingManager);
Expand Down
14 changes: 14 additions & 0 deletions tensorflow_serving/core/caching_manager_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ limitations under the License.
#include "tensorflow_serving/core/servable_state.h"
#include "tensorflow_serving/core/servable_state_monitor.h"
#include "tensorflow_serving/core/simple_loader.h"
#include "tensorflow_serving/core/test_util/manager_test_util.h"
#include "tensorflow_serving/util/event_bus.h"
#include "tensorflow_serving/util/optional.h"
#include "tensorflow_serving/util/threadpool_executor.h"
Expand Down Expand Up @@ -177,6 +178,13 @@ class CachingManagerTest : public ::testing::TestWithParam<int> {
return error_manager;
}

// Test helper: reads the number of entries in the caching-manager's
// load-mutex map through the CachingManagerTestAccess utility.
int64 GetLoadMutexMapSize() {
  const test_util::CachingManagerTestAccess test_access(manager_.get());
  return test_access.GetLoadMutexMapSize();
}

std::shared_ptr<EventBus<ServableState>> servable_event_bus_;
ServableStateMonitor servable_state_monitor_;
std::unique_ptr<CachingManager> manager_;
Expand Down Expand Up @@ -457,6 +465,9 @@ TEST_P(CachingManagerTest, ConcurrentDisjointRequests) {
{kServableName, 32},
{kServableName, 33}};
EXPECT_THAT(actual_keys, UnorderedElementsAreArray(expected_keys));
// Since the map entries in load_mutex_map_ are garbage-collected, we expect
// no remaining entries in the map.
EXPECT_EQ(0, GetLoadMutexMapSize());
}

TEST_P(CachingManagerTest, ConcurrentIntersectingRequests) {
Expand Down Expand Up @@ -494,6 +505,9 @@ TEST_P(CachingManagerTest, ConcurrentIntersectingRequests) {
const std::vector<ServableId> expected_keys = {{kServableName, 30},
{kServableName, 31}};
EXPECT_THAT(actual_keys, UnorderedElementsAreArray(expected_keys));
// Since the map entries in load_mutex_map_ are garbage-collected, we expect
// no remaining entries in the map.
EXPECT_EQ(0, GetLoadMutexMapSize());
}

///////////////////////////////////////////////////////////////////////////////
Expand Down
1 change: 1 addition & 0 deletions tensorflow_serving/core/test_util/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -107,5 +107,6 @@ cc_library(
hdrs = ["manager_test_util.h"],
deps = [
"//tensorflow_serving/core:aspired_versions_manager",
"//tensorflow_serving/core:caching_manager",
],
)
4 changes: 4 additions & 0 deletions tensorflow_serving/core/test_util/manager_test_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ void AspiredVersionsManagerTestAccess::RunManageState() {
manager_->ManageState();
}

// Forwards to the wrapped manager's GetLoadMutexMapSize().
int64 CachingManagerTestAccess::GetLoadMutexMapSize() const {
  const int64 load_mutex_map_size = manager_->GetLoadMutexMapSize();
  return load_mutex_map_size;
}

} // namespace test_util
} // namespace serving
} // namespace tensorflow
17 changes: 17 additions & 0 deletions tensorflow_serving/core/test_util/manager_test_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_SERVING_CORE_TEST_UTIL_MANAGER_TEST_UTIL_H_

#include "tensorflow_serving/core/aspired_versions_manager.h"
#include "tensorflow_serving/core/caching_manager.h"

namespace tensorflow {
namespace serving {
Expand All @@ -37,6 +38,22 @@ class AspiredVersionsManagerTestAccess {
TF_DISALLOW_COPY_AND_ASSIGN(AspiredVersionsManagerTestAccess);
};

// A test utility that provides access to private CachingManager members.
//
// CachingManager declares this class as a friend, which is what lets the
// accessors below reach its private methods. Instances are non-copyable.
class CachingManagerTestAccess {
public:
// Wraps 'manager'. The pointer is stored as-is and is not deleted by this
// class; presumably the caller keeps the manager alive for as long as this
// accessor is used -- TODO confirm against callers.
explicit CachingManagerTestAccess(CachingManager* manager)
: manager_(manager) {}

// Returns the size of the load-mutex map that stores the mutex reference per
// servable-id requested for load.
int64 GetLoadMutexMapSize() const;

private:
// The manager whose private members are exposed; the pointer itself cannot
// be reseated after construction.
CachingManager* const manager_;

TF_DISALLOW_COPY_AND_ASSIGN(CachingManagerTestAccess);
};

} // namespace test_util
} // namespace serving
} // namespace tensorflow
Expand Down
39 changes: 39 additions & 0 deletions tensorflow_serving/session_bundle/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,36 @@ filegroup(
),
)

# Python library target built from constants.py. It declares no deps and no
# explicit visibility.
py_library(
name = "constants",
srcs = ["constants.py"],
)

# Publicly visible Python library target built from exporter.py. Depends on
# the TensorFlow Python target plus the local constants, gc and manifest proto
# targets.
py_library(
name = "exporter",
srcs = ["exporter.py"],
visibility = ["//visibility:public"],
deps = [
"@tf//tensorflow:tensorflow_py",
":constants",
":gc",
":manifest_proto_py",
],
)

# Publicly visible Python library target built from session_bundle.py. Depends
# on the TensorFlow Python and core proto targets plus the local constants,
# exporter and manifest proto targets.
py_library(
name = "session_bundle_py",
srcs = ["session_bundle.py"],
visibility = ["//visibility:public"],
deps = [
"@tf//tensorflow:tensorflow_py",
"@tf//tensorflow/core:protos_all_py",
":constants",
":exporter",
":manifest_proto_py",
],
)

py_test(
name = "exporter_test",
size = "small",
Expand All @@ -45,12 +64,32 @@ py_test(
],
deps = [
"@tf//tensorflow:tensorflow_py",
":constants",
":exporter",
":gc",
":manifest_proto_py",
],
)

# Small-sized Python test whose source and main module is
# session_bundle_test.py. The half_plus_two example target is supplied as
# runtime data, and the test depends on the session_bundle_py library.
py_test(
name = "session_bundle_py_test",
size = "small",
srcs = [
"session_bundle_test.py",
],
data = [
"//tensorflow_serving/session_bundle/example:half_plus_two",
],
main = "session_bundle_test.py",
deps = [
"@tf//google/protobuf:protobuf_python",
"@tf//tensorflow:tensorflow_py",
"@tf//tensorflow/core:protos_all_py",
":manifest_proto_py",
":session_bundle_py",
],
)

py_library(
name = "gc",
srcs = ["gc.py"],
Expand Down

0 comments on commit b4e9815

Please sign in to comment.