Skip to content

Commit

Permalink
[SavedModel Fingerprinting] Add hash #4 to represent the canonicalize…
Browse files Browse the repository at this point in the history
…d SavedObjectGraph.

This commit only looks at the `concrete_functions` of the SavedObjectGraph, ignoring the `nodes`.

RFC: tensorflow/community#415
PiperOrigin-RevId: 465475648
  • Loading branch information
Monica Song authored and tensorflower-gardener committed Aug 5, 2022
1 parent cf29afd commit c371435
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 0 deletions.
1 change: 1 addition & 0 deletions tensorflow/cc/saved_model/BUILD
Expand Up @@ -386,6 +386,7 @@ cc_library(
"@com_google_protobuf//:protobuf_headers",
"//tensorflow/core/grappler:op_types",
"@com_google_absl//absl/container:btree",
"@com_google_absl//absl/strings",
] + if_not_mobile(["//tensorflow/core:lib"]) + if_android(["//tensorflow/core:portable_tensorflow_lib_lite"]),
alwayslink = True,
)
Expand Down
59 changes: 59 additions & 0 deletions tensorflow/cc/saved_model/fingerprinting.cc
Expand Up @@ -20,17 +20,24 @@ limitations under the License.
#include <vector>

#include "absl/container/btree_map.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_split.h"
#include "absl/strings/strip.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/proto_serialization.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/protobuf/fingerprint.pb.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/core/protobuf/saved_model.pb.h"
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"

namespace tensorflow::fingerprinting {

Expand Down Expand Up @@ -61,6 +68,17 @@ void CanonicalizeNodes(GraphDef* orig_graph_def) {
}
}

// Returns the suffix UID of `function_name`.
StatusOr<int> GetSuffixUID(absl::string_view function_name) {
std::vector<std::string> v = absl::StrSplit(function_name, '_');
int uid;
if (!strings::safe_strto32(v.back(), &uid)) {
return errors::InvalidArgument(absl::StrCat(
"Function name: `", function_name, "` does not end in an integer."));
}
return uid;
}

} // namespace

uint64 ComputeHash(const GraphDef& graph_def) {
Expand All @@ -84,6 +102,11 @@ FingerprintDef CreateFingerprintDef(const MetaGraphDef& metagraph) {
// Set fingerprint field #3.
fingerprint_def.set_signature_def_hash(
RegularizeAndHashSignatureDefs(metagraph_copy.signature_def()));
// Set fingerprint field #4.
StatusOr<uint64> object_graph_hash =
RegularizeAndHashSavedObjectGraph(metagraph_copy.object_graph_def());
fingerprint_def.set_saved_object_graph_hash(
RegularizeAndHashSavedObjectGraph(metagraph_copy.object_graph_def()));
return fingerprint_def;
}

Expand Down Expand Up @@ -114,4 +137,40 @@ uint64 RegularizeAndHashSignatureDefs(
return result_hash;
}

// The SavedObjectGraph contains two parts: the list of nodes and the map of
// concrete functions. Regularization treats these two parts separately.
uint64 RegularizeAndHashSavedObjectGraph(
const SavedObjectGraph& object_graph_def) {
// Sort `concrete_functions`, which is an unordered map from function names to
// SavedConcreteFunction, using the suffix UID of the function name. Assumes
// that the trackable children are listed in a deterministic order during
// serialization.
absl::btree_map<int, std::string> uid_to_function_names;
for (const auto& [name, concrete_function] :
object_graph_def.concrete_functions()) {
StatusOr<int> uid = GetSuffixUID(name);
// All valid function names should end in an UID.
if (uid.ok()) {
uid_to_function_names.insert({*uid, name});
} else {
LOG(ERROR) << uid.status().error_message();
}
}
uint64 result_hash = 0;
for (const auto& [uid, function_name] : uid_to_function_names) {
// Hash the function name (with the UID stripped).
result_hash = FingerprintCat64(result_hash,
tensorflow::Fingerprint64(absl::StripSuffix(
function_name, std::to_string(uid))));
// Hash the serialized concrete function.
std::string concrete_function_string;
SerializeToStringDeterministic(
object_graph_def.concrete_functions().at(function_name),
&concrete_function_string);
result_hash = FingerprintCat64(
result_hash, tensorflow::Fingerprint64(concrete_function_string));
}
// TODO(b/241294832): Complete canonicalization of `object_graph_def.nodes`.
return result_hash;
}
} // namespace tensorflow::fingerprinting
4 changes: 4 additions & 0 deletions tensorflow/cc/saved_model/fingerprinting.h
Expand Up @@ -33,6 +33,10 @@ uint64 ComputeHash(const GraphDef& graph_def);
uint64 RegularizeAndHashSignatureDefs(
const google::protobuf::Map<std::string, SignatureDef>& signature_def_map);

// Canonicalizes and computes the Fingerprint64 hash of the SavedObjectGraph.
uint64 RegularizeAndHashSavedObjectGraph(
const SavedObjectGraph& object_graph_def);

// Creates a FingerprintDef proto from a MetaGraph.
FingerprintDef CreateFingerprintDef(const MetaGraphDef& metagraph);

Expand Down
3 changes: 3 additions & 0 deletions tensorflow/cc/saved_model/fingerprinting_test.cc
Expand Up @@ -71,6 +71,7 @@ TEST(FingerprintingTest, TestCreateFingerprint) {

EXPECT_GT(fingerprint_def.graph_def_checksum(), 0);
EXPECT_EQ(fingerprint_def.signature_def_hash(), 5693392539583495303);
EXPECT_EQ(fingerprint_def.saved_object_graph_hash(), 3678101440349108924);
}

// Test that canonicalization returns the same hash for two models saved by
Expand Down Expand Up @@ -124,6 +125,8 @@ TEST(FingerprintingTest, TestCompareFingerprintForTwoModelSavedTwice) {
fingerprint_def2.graph_def_program_hash());
EXPECT_EQ(fingerprint_def.signature_def_hash(),
fingerprint_def2.signature_def_hash());
EXPECT_EQ(fingerprint_def.saved_object_graph_hash(),
fingerprint_def2.saved_object_graph_hash());
}

TEST(FingerprintingTest, TestFingerprintComputationDoesNotMutateModel) {
Expand Down
2 changes: 2 additions & 0 deletions tensorflow/core/protobuf/fingerprint.proto
Expand Up @@ -19,4 +19,6 @@ message FingerprintDef {
uint64 graph_def_program_hash = 2;
// Hash of the regularized (sorted) SignatureDefs.
uint64 signature_def_hash = 3;
// Hash of the regularized SavedObjectGraph.
uint64 saved_object_graph_hash = 4;
}

0 comments on commit c371435

Please sign in to comment.