Skip to content

Commit

Permalink
Merge pull request #62307 from tensorflow/r2.15-a1fd78b23b1
Browse files Browse the repository at this point in the history
r2.15 cherry-pick: a1fd78b "Potential fix - try deleting old DeviceCompiler if new PjRtClient found for TPU."
  • Loading branch information
learning-to-play committed Nov 2, 2023
2 parents c5a43fa + d16adad commit cca5fda
Showing 1 changed file with 22 additions and 3 deletions.
25 changes: 22 additions & 3 deletions tensorflow/compiler/jit/xla_platform_info.cc
Original file line number Diff line number Diff line change
Expand Up @@ -277,11 +277,32 @@ Status GetOrCreatePjRtDeviceCompilerAndProfiler(
const auto& device_type = platform_info.device_type();
const std::string& compiler_name =
GetPjRtDeviceCompilerResourceName(device_type);
const std::string& profiler_name =
GetPjRtDeviceCompilationProfilerResourceName(device_type);
bool deleted_old_device_compiler = false;

// Lookup the DeviceCompiler, create one if not found.
Status s = rm->Lookup<PjRtDeviceCompiler>(
rm->default_container(), compiler_name, pjrt_device_compiler);
if (!s.ok()) {
if (s.ok() && device_type == DEVICE_TPU) {
auto* existing_pjrt_client = (*pjrt_device_compiler)->client();
TF_ASSIGN_OR_RETURN(auto* latest_pjrt_client, GetPjRtClient(device_type));

if (existing_pjrt_client != latest_pjrt_client) {
// PjRtClient has changed. Delete the PjRtDeviceCompiler (and the cache
// within) and create a new one.
TF_RETURN_IF_ERROR(rm->Delete<PjRtDeviceCompiler>(rm->default_container(),
compiler_name));
TF_RETURN_IF_ERROR(rm->Delete<DeviceCompilationProfiler>(
rm->default_container(), profiler_name));

deleted_old_device_compiler = true;
}
}

// TODO(b/308698131): Try consolidating all PJRT-related state into one class
// instead of directly storing it in the ResourceMgr.
if (!s.ok() || deleted_old_device_compiler) {
DeviceType compilation_device_type("");
xla::PjRtClient* pjrt_client = nullptr;
TF_RETURN_IF_ERROR(GetCompilationDeviceTypeAndPjRtClient(
Expand All @@ -296,8 +317,6 @@ Status GetOrCreatePjRtDeviceCompilerAndProfiler(
}));
}

const std::string& profiler_name =
GetPjRtDeviceCompilationProfilerResourceName(device_type);
TF_RETURN_IF_ERROR(rm->LookupOrCreate<DeviceCompilationProfiler>(
rm->default_container(), profiler_name, profiler,
[](DeviceCompilationProfiler** profiler) {
Expand Down

0 comments on commit cca5fda

Please sign in to comment.