Skip to content

Commit 5760ad4

Browse files
swachhandltensorflower-gardener
authored andcommitted
Add a compiled_using_pjrt field to persistent cache key.
This is to differentiate the executables that would be built by XLA and PJRT clients when they are being serialized to disk for persistence. There would be no changes to the filenames of XLA serialized executables. PJRT serialized executables will have '__pjrt' appended to their filenames. PiperOrigin-RevId: 504338570
1 parent f259468 commit 5760ad4

File tree

6 files changed

+248
-72
lines changed

6 files changed

+248
-72
lines changed

tensorflow/compiler/jit/BUILD

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1299,6 +1299,8 @@ cc_library(
12991299
":xla_device_compiler_client",
13001300
"//tensorflow/compiler/tf2xla:xla_compiler",
13011301
"//tensorflow/compiler/xla:util",
1302+
"//tensorflow/compiler/xla/client:local_client",
1303+
"//tensorflow/compiler/xla/pjrt:pjrt_client",
13021304
"//tensorflow/compiler/xla/service:hlo_proto_cc",
13031305
"//tensorflow/core:core_cpu_base",
13041306
"//tensorflow/core:framework",
@@ -1393,6 +1395,7 @@ tf_cc_test(
13931395
deps = [
13941396
":device_compiler_client",
13951397
":device_executable_persistor",
1398+
":pjrt_device_compiler_client",
13961399
":xla_compilation_cache_proto_cc",
13971400
":xla_cpu_device",
13981401
":xla_cpu_jit",
@@ -1403,10 +1406,13 @@ tf_cc_test(
14031406
"//tensorflow/compiler/xla/client:client_library",
14041407
"//tensorflow/compiler/xla/client:executable_build_options",
14051408
"//tensorflow/compiler/xla/client:local_client",
1409+
"//tensorflow/compiler/xla/pjrt:pjrt_client",
1410+
"//tensorflow/compiler/xla/pjrt:tfrt_cpu_pjrt_client",
14061411
"//tensorflow/core:test",
14071412
"//tensorflow/core/platform:errors",
14081413
"//tensorflow/core/platform:status_matchers",
14091414
"//tensorflow/core/platform:statusor",
1415+
"//tensorflow/core/tfrt/common:pjrt_util",
14101416
"@com_google_googletest//:gtest_main",
14111417
],
14121418
)

tensorflow/compiler/jit/device_executable_persistor.h

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ limitations under the License.
2222
#include "tensorflow/compiler/jit/xla_compilation_cache.pb.h"
2323
#include "tensorflow/compiler/jit/xla_device_compiler_client.h"
2424
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
25+
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
2526
#include "tensorflow/compiler/xla/service/hlo.pb.h"
2627
#include "tensorflow/compiler/xla/util.h"
2728
#include "tensorflow/core/framework/device.h"
@@ -100,6 +101,10 @@ class DeviceExecutablePersistor {
100101
XlaSerializedCacheKey BuildSerializedCacheKey(
101102
uint64 signature_hash, const xla::HloModuleProto& hlo_module) const;
102103

104+
XlaSerializedCacheKey BuildSerializedCacheKey(
105+
uint64 signature_hash, const xla::HloModuleProto& hlo_module,
106+
bool compiled_using_pjrt) const;
107+
103108
// Serializes the signature and its corresponding entry to a proto message.
104109
StatusOr<XlaSerializedCacheEntry> SerializeEntry(
105110
uint64 signature_hash, const XlaCompiler::Options& options,
@@ -154,7 +159,10 @@ std::string DeviceExecutablePersistor<ExecutableType, ClientType>::
154159
key.prefix(), key.prefix().empty() ? "" : kXlaSerializedCacheKeySeparator,
155160
key.signature_fingerprint(), kXlaSerializedCacheKeySeparator,
156161
key.cluster_fingerprint(), kXlaSerializedCacheKeySeparator,
157-
key.device_type());
162+
key.device_type(),
163+
key.compiled_using_pjrt()
164+
? absl::StrCat(kXlaSerializedCacheKeySeparator, "pjrt")
165+
: "");
158166
}
159167

160168
template <typename ExecutableType, typename ClientType>
@@ -165,17 +173,35 @@ std::string DeviceExecutablePersistor<ExecutableType, ClientType>::GetFilePath(
165173
return io::JoinPath(persistent_cache_directory_, file_name);
166174
}
167175

176+
template <typename ExecutableType, typename ClientType>
177+
XlaSerializedCacheKey
178+
DeviceExecutablePersistor<ExecutableType, ClientType>::BuildSerializedCacheKey(
179+
uint64 signature_hash, const xla::HloModuleProto& hlo_module,
180+
bool compiled_using_pjrt) const {
181+
XlaSerializedCacheKey key;
182+
key.set_signature_fingerprint(signature_hash);
183+
key.set_cluster_fingerprint(DeterministicProtoHash64(hlo_module));
184+
key.set_device_type(device_type().type_string());
185+
key.set_prefix(persistence_prefix());
186+
key.set_compiled_using_pjrt(compiled_using_pjrt);
187+
return key;
188+
}
189+
168190
template <typename ExecutableType, typename ClientType>
169191
XlaSerializedCacheKey
170192
DeviceExecutablePersistor<ExecutableType, ClientType>::BuildSerializedCacheKey(
171193
uint64 signature_hash, const xla::HloModuleProto& hlo_module) const {
172-
XlaSerializedCacheKey serialized_cache_key;
173-
serialized_cache_key.set_signature_fingerprint(signature_hash);
174-
serialized_cache_key.set_cluster_fingerprint(
175-
DeterministicProtoHash64(hlo_module));
176-
serialized_cache_key.set_device_type(device_type().type_string());
177-
serialized_cache_key.set_prefix(persistence_prefix());
178-
return serialized_cache_key;
194+
return BuildSerializedCacheKey(signature_hash, hlo_module, false);
195+
}
196+
197+
// This template specialization sets compiled_using_prjt to true in the cache
198+
// key when the template arguments are PjRtLoadedExecutable and PjRtClient.
199+
template <>
200+
inline XlaSerializedCacheKey
201+
DeviceExecutablePersistor<xla::PjRtLoadedExecutable, xla::PjRtClient>::
202+
BuildSerializedCacheKey(uint64 signature_hash,
203+
const xla::HloModuleProto& hlo_module) const {
204+
return BuildSerializedCacheKey(signature_hash, hlo_module, true);
179205
}
180206

181207
template <typename ExecutableType, typename ClientType>

0 commit comments

Comments
 (0)