Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SavedModelBundleLite] Avoid copying the GraphDef during the load path. #64858

Merged
merged 1 commit into from
Apr 2, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
118 changes: 78 additions & 40 deletions tensorflow/cc/saved_model/loader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@ limitations under the License.
#include <memory>
#include <string>
#include <unordered_set>
#include <utility>

#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/cc/saved_model/fingerprinting.h"
#include "tensorflow/cc/saved_model/loader_util.h"
Expand Down Expand Up @@ -280,6 +282,16 @@ Status LoadMetagraphIntoSession(const SessionOptions& session_options,
return (*session)->Create(meta_graph.graph_def());
}

// Creates a new Session from `session_options`, stores it in `*session`, and
// then loads `graph_def` into it.
//
// `graph_def` is taken by value so a caller can move its graph in, and the
// graph is moved onward into Session::Create() rather than passed as a
// const reference.
//
// Any non-OK status from NewSession(), ValidateSavedTensors() or
// Session::Create() is propagated to the caller. `*session` is assigned before
// the tensors are validated, so it remains set when validation or creation
// fails.
Status LoadGraphDefIntoSession(const SessionOptions& session_options,
                               GraphDef graph_def,
                               std::unique_ptr<Session>* session) {
  Session* raw_session = nullptr;
  TF_RETURN_IF_ERROR(NewSession(session_options, &raw_session));
  // Hand the raw pointer to the caller's unique_ptr before any later step can
  // return early.
  session->reset(raw_session);

  TF_RETURN_IF_ERROR(ValidateSavedTensors(graph_def));
  return session->get()->Create(std::move(graph_def));
}

Status LoadSavedModelInternal(const SessionOptions& session_options,
const RunOptions& run_options,
const string& export_dir,
Expand All @@ -296,40 +308,6 @@ Status LoadSavedModelInternal(const SessionOptions& session_options,
return absl::OkStatus();
}

// Loads the SavedModel at `export_dir` matching `tags` into `bundle`, and
// records load metrics around the attempt:
//   * bumps the read-API counter for the C++ loader label,
//   * publishes the fingerprint checksum when a fingerprint can be read,
//   * logs the outcome and increments the per-outcome attempt counter,
//   * records the model path on success, and
//   * adds the elapsed time to the load-latency cell.
//
// Returns the status produced by LoadSavedModelInternal().
Status LoadSavedModel(const SessionOptions& session_options,
                      const RunOptions& run_options, const string& export_dir,
                      const std::unordered_set<string>& tags,
                      SavedModelBundle* const bundle) {
  metrics::SavedModelReadApi(kCCLoadLabel).IncrementBy(1);

  // The checksum gauge is only updated when the fingerprint was readable.
  const auto fingerprint =
      saved_model::fingerprinting::ReadSavedModelFingerprint(export_dir);
  if (fingerprint.ok()) {
    metrics::SavedModelReadFingerprint().Set(
        std::to_string(fingerprint->saved_model_checksum()));
  }

  // TODO(robson): Add tests for the counters.
  const uint64 load_start_micros = Env::Default()->NowMicros();
  const Status load_status = LoadSavedModelInternal(
      session_options, run_options, export_dir, tags, bundle);
  const bool loaded = load_status.ok();

  // Label used both in the log line and as the attempt-counter cell key.
  const string outcome =
      loaded ? string(kLoadAttemptSuccess) : string(kLoadAttemptFail);
  LOG(INFO) << "SavedModel load for tags { " << absl::StrJoin(tags, " ")
            << " }; Status: " << outcome << ": " << load_status << ". Took "
            << GetLatencyMicroseconds(load_start_micros) << " microseconds.";
  load_attempt_count->GetCell(export_dir, outcome)->IncrementBy(1);
  if (loaded) {
    metrics::SavedModelReadPath().Set(export_dir);
  }

  // Latency is sampled again here, so it also covers the logging above.
  load_latency->GetCell(export_dir)
      ->IncrementBy(GetLatencyMicroseconds(load_start_micros));
  return load_status;
}

namespace {
// Session wrapper that prevents calls to Session::Create(), Session::Extend(),
// and the deprecated partial-run methods.
Expand Down Expand Up @@ -441,6 +419,70 @@ class LiteSessionWrapper : public Session {
};
} // namespace

// SavedModelBundleLite overload: reads the MetaGraphDef for `tags` from
// `export_dir`, builds a session from its graph, restores the session state,
// and stores the result in `*bundle`.
//
// The GraphDef and the signature map are moved out of the MetaGraphDef rather
// than copied. On any failure the error status is returned and `*bundle` is
// not assigned.
Status LoadSavedModelInternal(const SessionOptions& session_options,
                              const RunOptions& run_options,
                              const string& export_dir,
                              const std::unordered_set<string>& tags,
                              SavedModelBundleLite* const bundle) {
  MetaGraphDef meta_graph;
  TF_RETURN_IF_ERROR(
      ReadMetaGraphDefFromSavedModel(export_dir, tags, &meta_graph));

  // The graph is moved into the session here, so `meta_graph` no longer
  // carries a usable GraphDef from this point on.
  std::unique_ptr<Session> loaded_session;
  TF_RETURN_IF_ERROR(LoadGraphDefIntoSession(
      session_options, std::move(*meta_graph.mutable_graph_def()),
      &loaded_session));

  // NOTE(review): RestoreSession receives the MetaGraphDef after its graph
  // was moved out; presumably it only needs the non-graph fields — confirm.
  TF_RETURN_IF_ERROR(
      RestoreSession(run_options, meta_graph, export_dir, &loaded_session));

  auto lite_session =
      std::make_unique<LiteSessionWrapper>(std::move(loaded_session));
  *bundle = SavedModelBundleLite(
      std::move(lite_session),
      std::move(*meta_graph.mutable_signature_def()));
  return absl::OkStatus();
}

// Shared load path for every bundle type: forwards to the
// LoadSavedModelInternal() overload selected by `BundleType` and records load
// metrics around the call (read-API counter, fingerprint checksum when
// available, per-outcome attempt counter, model path on success, and load
// latency).
//
// Returns the status produced by LoadSavedModelInternal().
template <typename BundleType>
Status LoadSavedModelGeneric(const SessionOptions& session_options,
                             const RunOptions& run_options,
                             const string& export_dir,
                             const std::unordered_set<string>& tags,
                             BundleType* const bundle) {
  metrics::SavedModelReadApi(kCCLoadLabel).IncrementBy(1);

  // Set the checksum gauge only if the fingerprint could be read.
  const auto fingerprint =
      saved_model::fingerprinting::ReadSavedModelFingerprint(export_dir);
  if (fingerprint.ok()) {
    metrics::SavedModelReadFingerprint().Set(
        std::to_string(fingerprint->saved_model_checksum()));
  }

  // TODO(robson): Add tests for the counters.
  const uint64 begin_micros = Env::Default()->NowMicros();
  const Status result = LoadSavedModelInternal(session_options, run_options,
                                               export_dir, tags, bundle);

  // Pick the label that is both logged and used as the counter cell key.
  string attempt_label;
  if (result.ok()) {
    attempt_label = kLoadAttemptSuccess;
  } else {
    attempt_label = kLoadAttemptFail;
  }

  LOG(INFO) << "SavedModel load for tags { " << absl::StrJoin(tags, " ")
            << " }; Status: " << attempt_label << ": " << result << ". Took "
            << GetLatencyMicroseconds(begin_micros) << " microseconds.";
  load_attempt_count->GetCell(export_dir, attempt_label)->IncrementBy(1);

  // The read path is only published for successful loads.
  if (result.ok()) {
    metrics::SavedModelReadPath().Set(export_dir);
  }

  load_latency->GetCell(export_dir)
      ->IncrementBy(GetLatencyMicroseconds(begin_micros));
  return result;
}

// Public entry point for loading a full SavedModelBundle. All work, including
// metrics and logging, is delegated to LoadSavedModelGeneric(); the bundle
// type is deduced from `bundle`.
Status LoadSavedModel(const SessionOptions& session_options,
                      const RunOptions& run_options, const string& export_dir,
                      const std::unordered_set<string>& tags,
                      SavedModelBundle* const bundle) {
  return LoadSavedModelGeneric(session_options, run_options, export_dir, tags,
                               bundle);
}

Status RestoreSession(const RunOptions& run_options,
const MetaGraphDef& meta_graph, const string& export_dir,
std::unique_ptr<Session>* session) {
Expand Down Expand Up @@ -476,7 +518,6 @@ Status LoadSavedModel(const SessionOptions& session_options,
const RunOptions& run_options, const string& export_dir,
const std::unordered_set<string>& tags,
SavedModelBundleLite* const bundle) {
SavedModelBundle legacy_bundle;
SessionOptions rewritten_options(session_options);
// We disallow calls to Session::Extend() on the returned session, so we can
// reduce memory consumption by not storing the original GraphDef.
Expand All @@ -489,11 +530,8 @@ Status LoadSavedModel(const SessionOptions& session_options,
->set_disable_output_partition_graphs(true);
// TODO(mrry): Consider specializing the session creation to reduce peak
// RAM consumption by using `Session::Create(GraphDef&&)`.
TF_RETURN_IF_ERROR(LoadSavedModel(rewritten_options, run_options, export_dir,
tags, &legacy_bundle));
*bundle = SavedModelBundleLite(
std::make_unique<LiteSessionWrapper>(std::move(legacy_bundle.session)),
std::move(*legacy_bundle.meta_graph_def.mutable_signature_def()));
TF_RETURN_IF_ERROR(LoadSavedModelGeneric(rewritten_options, run_options,
export_dir, tags, bundle));
return absl::OkStatus();
}

Expand Down