diff --git a/tensorflow/cc/saved_model/BUILD b/tensorflow/cc/saved_model/BUILD index 7dfef7fa3e3362..2d06f4be7267be 100644 --- a/tensorflow/cc/saved_model/BUILD +++ b/tensorflow/cc/saved_model/BUILD @@ -386,6 +386,7 @@ cc_library( "@com_google_protobuf//:protobuf_headers", "//tensorflow/core/grappler:op_types", "@com_google_absl//absl/container:btree", + "@com_google_absl//absl/strings", ] + if_not_mobile(["//tensorflow/core:lib"]) + if_android(["//tensorflow/core:portable_tensorflow_lib_lite"]), alwayslink = True, ) diff --git a/tensorflow/cc/saved_model/fingerprinting.cc b/tensorflow/cc/saved_model/fingerprinting.cc index 228f7238cb8fea..8e71ac6035efd2 100644 --- a/tensorflow/cc/saved_model/fingerprinting.cc +++ b/tensorflow/cc/saved_model/fingerprinting.cc @@ -20,17 +20,24 @@ limitations under the License. #include #include "absl/container/btree_map.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_split.h" +#include "absl/strings/strip.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/grappler/op_types.h" +#include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/lib/strings/proto_serialization.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/fingerprint.h" +#include "tensorflow/core/platform/statusor.h" #include "tensorflow/core/protobuf/fingerprint.pb.h" #include "tensorflow/core/protobuf/meta_graph.pb.h" #include "tensorflow/core/protobuf/saved_model.pb.h" +#include "tensorflow/core/protobuf/saved_object_graph.pb.h" namespace tensorflow::fingerprinting { @@ -61,6 +68,17 @@ void CanonicalizeNodes(GraphDef* orig_graph_def) { } } +// Returns the suffix UID of `function_name`. +StatusOr GetSuffixUID(absl::string_view function_name) { + std::vector v = absl::StrSplit(function_name, '_'); + int uid; + if (!strings::safe_strto32(v.back(), &uid)) { + return errors::InvalidArgument(absl::StrCat( + "Function name: `", function_name, "` does not end in an integer.")); + } + return uid; +} + } // namespace uint64 ComputeHash(const GraphDef& graph_def) { @@ -84,6 +102,11 @@ FingerprintDef CreateFingerprintDef(const MetaGraphDef& metagraph) { // Set fingerprint field #3. fingerprint_def.set_signature_def_hash( RegularizeAndHashSignatureDefs(metagraph_copy.signature_def())); + // Set fingerprint field #4. + StatusOr object_graph_hash = + RegularizeAndHashSavedObjectGraph(metagraph_copy.object_graph_def()); + fingerprint_def.set_saved_object_graph_hash( + RegularizeAndHashSavedObjectGraph(metagraph_copy.object_graph_def())); return fingerprint_def; } @@ -114,4 +137,40 @@ uint64 RegularizeAndHashSignatureDefs( return result_hash; } +// The SavedObjectGraph contains two parts: the list of nodes and the map of +// concrete functions. Regularization treats these two parts separately. +uint64 RegularizeAndHashSavedObjectGraph( + const SavedObjectGraph& object_graph_def) { + // Sort `concrete_functions`, which is an unordered map from function names to + // SavedConcreteFunction, using the suffix UID of the function name. Assumes + // that the trackable children are listed in a deterministic order during + // serialization. + absl::btree_map uid_to_function_names; + for (const auto& [name, concrete_function] : + object_graph_def.concrete_functions()) { + StatusOr uid = GetSuffixUID(name); + // All valid function names should end in an UID. + if (uid.ok()) { + uid_to_function_names.insert({*uid, name}); + } else { + LOG(ERROR) << uid.status().error_message(); + } + } + uint64 result_hash = 0; + for (const auto& [uid, function_name] : uid_to_function_names) { + // Hash the function name (with the UID stripped). + result_hash = FingerprintCat64(result_hash, + tensorflow::Fingerprint64(absl::StripSuffix( + function_name, std::to_string(uid)))); + // Hash the serialized concrete function. + std::string concrete_function_string; + SerializeToStringDeterministic( + object_graph_def.concrete_functions().at(function_name), + &concrete_function_string); + result_hash = FingerprintCat64( + result_hash, tensorflow::Fingerprint64(concrete_function_string)); + } + // TODO(b/241294832): Complete canonicalization of `object_graph_def.nodes`. + return result_hash; +} } // namespace tensorflow::fingerprinting diff --git a/tensorflow/cc/saved_model/fingerprinting.h b/tensorflow/cc/saved_model/fingerprinting.h index 0f452c2ae8ccf3..2bedf67b5dce12 100644 --- a/tensorflow/cc/saved_model/fingerprinting.h +++ b/tensorflow/cc/saved_model/fingerprinting.h @@ -33,6 +33,10 @@ uint64 ComputeHash(const GraphDef& graph_def); uint64 RegularizeAndHashSignatureDefs( const google::protobuf::Map& signature_def_map); +// Canonicalizes and computes the Fingerprint64 hash of the SavedObjectGraph. +uint64 RegularizeAndHashSavedObjectGraph( + const SavedObjectGraph& object_graph_def); + // Creates a FingerprintDef proto from a MetaGraph. FingerprintDef CreateFingerprintDef(const MetaGraphDef& metagraph); diff --git a/tensorflow/cc/saved_model/fingerprinting_test.cc b/tensorflow/cc/saved_model/fingerprinting_test.cc index 3dde8305420941..7c1b9ab1facb1d 100644 --- a/tensorflow/cc/saved_model/fingerprinting_test.cc +++ b/tensorflow/cc/saved_model/fingerprinting_test.cc @@ -71,6 +71,7 @@ TEST(FingerprintingTest, TestCreateFingerprint) { EXPECT_GT(fingerprint_def.graph_def_checksum(), 0); EXPECT_EQ(fingerprint_def.signature_def_hash(), 5693392539583495303); + EXPECT_EQ(fingerprint_def.saved_object_graph_hash(), 3678101440349108924); } // Test that canonicalization returns the same hash for two models saved by @@ -124,6 +125,8 @@ TEST(FingerprintingTest, TestCompareFingerprintForTwoModelSavedTwice) { fingerprint_def2.graph_def_program_hash()); EXPECT_EQ(fingerprint_def.signature_def_hash(), fingerprint_def2.signature_def_hash()); + EXPECT_EQ(fingerprint_def.saved_object_graph_hash(), + fingerprint_def2.saved_object_graph_hash()); } TEST(FingerprintingTest, TestFingerprintComputationDoesNotMutateModel) { diff --git a/tensorflow/core/protobuf/fingerprint.proto b/tensorflow/core/protobuf/fingerprint.proto index 8b02749987db70..69a7eced85a43b 100644 --- a/tensorflow/core/protobuf/fingerprint.proto +++ b/tensorflow/core/protobuf/fingerprint.proto @@ -19,4 +19,6 @@ message FingerprintDef { uint64 graph_def_program_hash = 2; // Hash of the regularized (sorted) SignatureDefs. uint64 signature_def_hash = 3; + // Hash of the regularized SavedObjectGraph. + uint64 saved_object_graph_hash = 4; }