@@ -22,6 +22,7 @@ limitations under the License.
2222#include " tensorflow/compiler/jit/xla_compilation_cache.pb.h"
2323#include " tensorflow/compiler/jit/xla_device_compiler_client.h"
2424#include " tensorflow/compiler/tf2xla/xla_compiler.h"
25+ #include " tensorflow/compiler/xla/pjrt/pjrt_client.h"
2526#include " tensorflow/compiler/xla/service/hlo.pb.h"
2627#include " tensorflow/compiler/xla/util.h"
2728#include " tensorflow/core/framework/device.h"
@@ -100,6 +101,10 @@ class DeviceExecutablePersistor {
100101 XlaSerializedCacheKey BuildSerializedCacheKey (
101102 uint64 signature_hash, const xla::HloModuleProto& hlo_module) const ;
102103
104+ XlaSerializedCacheKey BuildSerializedCacheKey (
105+ uint64 signature_hash, const xla::HloModuleProto& hlo_module,
106+ bool compiled_using_pjrt) const ;
107+
103108 // Serializes the signature and its corresponding entry to a proto message.
104109 StatusOr<XlaSerializedCacheEntry> SerializeEntry (
105110 uint64 signature_hash, const XlaCompiler::Options& options,
@@ -154,7 +159,10 @@ std::string DeviceExecutablePersistor<ExecutableType, ClientType>::
154159 key.prefix (), key.prefix ().empty () ? " " : kXlaSerializedCacheKeySeparator ,
155160 key.signature_fingerprint (), kXlaSerializedCacheKeySeparator ,
156161 key.cluster_fingerprint (), kXlaSerializedCacheKeySeparator ,
157- key.device_type ());
162+ key.device_type (),
163+ key.compiled_using_pjrt ()
164+ ? absl::StrCat (kXlaSerializedCacheKeySeparator , " pjrt" )
165+ : " " );
158166}
159167
160168template <typename ExecutableType, typename ClientType>
@@ -165,17 +173,35 @@ std::string DeviceExecutablePersistor<ExecutableType, ClientType>::GetFilePath(
165173 return io::JoinPath (persistent_cache_directory_, file_name);
166174}
167175
176+ template <typename ExecutableType, typename ClientType>
177+ XlaSerializedCacheKey
178+ DeviceExecutablePersistor<ExecutableType, ClientType>::BuildSerializedCacheKey(
179+ uint64 signature_hash, const xla::HloModuleProto& hlo_module,
180+ bool compiled_using_pjrt) const {
181+ XlaSerializedCacheKey key;
182+ key.set_signature_fingerprint (signature_hash);
183+ key.set_cluster_fingerprint (DeterministicProtoHash64 (hlo_module));
184+ key.set_device_type (device_type ().type_string ());
185+ key.set_prefix (persistence_prefix ());
186+ key.set_compiled_using_pjrt (compiled_using_pjrt);
187+ return key;
188+ }
189+
168190template <typename ExecutableType, typename ClientType>
169191XlaSerializedCacheKey
170192DeviceExecutablePersistor<ExecutableType, ClientType>::BuildSerializedCacheKey(
171193 uint64 signature_hash, const xla::HloModuleProto& hlo_module) const {
172- XlaSerializedCacheKey serialized_cache_key;
173- serialized_cache_key.set_signature_fingerprint (signature_hash);
174- serialized_cache_key.set_cluster_fingerprint (
175- DeterministicProtoHash64 (hlo_module));
176- serialized_cache_key.set_device_type (device_type ().type_string ());
177- serialized_cache_key.set_prefix (persistence_prefix ());
178- return serialized_cache_key;
194+ return BuildSerializedCacheKey (signature_hash, hlo_module, false );
195+ }
196+
197+ // This template specialization sets compiled_using_prjt to true in the cache
198+ // key when the template arguments are PjRtLoadedExecutable and PjRtClient.
199+ template <>
200+ inline XlaSerializedCacheKey
201+ DeviceExecutablePersistor<xla::PjRtLoadedExecutable, xla::PjRtClient>::
202+ BuildSerializedCacheKey (uint64 signature_hash,
203+ const xla::HloModuleProto& hlo_module) const {
204+ return BuildSerializedCacheKey (signature_hash, hlo_module, true );
179205}
180206
181207template <typename ExecutableType, typename ClientType>
0 commit comments