Skip to content

Commit

Permalink
Disable colocate_predecessor_trees pass
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 621548580
  • Loading branch information
tensorflower-gardener committed Apr 3, 2024
1 parent 4dc5be0 commit b159314
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,8 @@ Status ColocatePredecessorTreesPass::Run(
return absl::OkStatus();
}

REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 50,
ColocatePredecessorTreesPass);
// TODO(b/331245915): Fix the regression issue then register the pass again.
// REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 50,
// ColocatePredecessorTreesPass);

} // namespace tensorflow
1 change: 1 addition & 0 deletions third_party/xla/xla/python/ifrt_proxy/client/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,7 @@ cc_library(
"@local_tsl//tsl/concurrency:ref_count",
"@local_tsl//tsl/platform:env",
"@local_tsl//tsl/platform:errors",
"@local_tsl//tsl/platform:platform_port",
"@local_tsl//tsl/platform:status_to_from_proto",
"@local_tsl//tsl/platform:statusor",
],
Expand Down
19 changes: 18 additions & 1 deletion third_party/xla/xla/python/ifrt_proxy/client/executable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,12 @@
#include "xla/shape_util.h"
#include "xla/xla_data.pb.h"
#include "tsl/concurrency/ref_count.h"
#include "tsl/platform/cpu_info.h"
#include "tsl/platform/env.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/status_to_from_proto.h"
#include "tsl/platform/statusor.h"
#include "tsl/platform/threadpool.h"

namespace xla {
namespace ifrt {
Expand Down Expand Up @@ -499,6 +501,17 @@ absl::Span<xla::ifrt::Device* const> LoadedExecutable::addressable_devices()
return addressable_devices_;
}

namespace {

static tsl::ThreadOptions GetThreadOptions() {
tsl::ThreadOptions thread_options;
// Ensure the threads' stack is large enough for arbitrary Python code.
thread_options.stack_size = 2 * 1024 * 1024; // 2 MiB
return thread_options;
}

} // namespace

void LoadedExecutable::PollLoadedHostCallback(
uint64_t handle,
tsl::RCReference<xla::ifrt::LoadedHostCallback> loaded_host_callback) {
Expand Down Expand Up @@ -554,7 +567,11 @@ void LoadedExecutable::PollLoadedHostCallback(
});
}
};
tsl::Env::Default()->SchedClosure(std::move(f));

static auto* global_pool = new tsl::thread::ThreadPool(
tsl::Env::Default(), GetThreadOptions(), "XLAIFRTProxy",
std::min(16, tsl::port::MaxParallelism()));
global_pool->Schedule(std::move(f));
}

char LoadedExecutable::ID = 0; // NOLINT
Expand Down

0 comments on commit b159314

Please sign in to comment.