Commit fa6d2e3: Automated Code Change
PiperOrigin-RevId: 622249641
klucke authored and tensorflower-gardener committed Apr 5, 2024
1 parent dcc58f9 commit fa6d2e3
Showing 20 changed files with 150 additions and 148 deletions.
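
This commit is a mechanical rename of xla::Status to absl::Status across PJRT's distributed runtime. The rename is behavior-preserving because xla::Status is a type alias for absl::Status; the sketch below illustrates the relationship (the alias declaration is an assumption about XLA's status header, not code from this diff):

// Minimal illustration -- not code from this commit.
#include <type_traits>

#include "absl/status/status.h"

namespace xla {
// Assumed alias: xla::Status names the same type as absl::Status, so the
// rename in this commit changes spelling only, never behavior.
using Status = absl::Status;
}  // namespace xla

static_assert(std::is_same_v<xla::Status, absl::Status>,
              "both names must refer to the same type");

Spelling the Abseil type directly lets the alias be retired later without revisiting these files.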
24 changes: 12 additions & 12 deletions third_party/xla/xla/pjrt/distributed/client.cc
@@ -48,17 +48,17 @@ class DistributedRuntimeCoordinationServiceClient
       : DistributedRuntimeCoordinationServiceClient(channel, Options()) {}
   ~DistributedRuntimeCoordinationServiceClient() override;

-  xla::Status Connect() override;
-  xla::Status Shutdown() override;
+  absl::Status Connect() override;
+  absl::Status Shutdown() override;
   absl::StatusOr<std::string> BlockingKeyValueGet(
       std::string_view key, absl::Duration timeout) override;
   absl::StatusOr<std::vector<std::pair<std::string, std::string>>>
   KeyValueDirGet(std::string_view key) override;
-  xla::Status KeyValueSet(std::string_view key,
-                          std::string_view value) override;
-  xla::Status KeyValueDelete(std::string_view key) override;
-  xla::Status WaitAtBarrier(std::string barrier_id,
-                            absl::Duration timeout) override;
+  absl::Status KeyValueSet(std::string_view key,
+                           std::string_view value) override;
+  absl::Status KeyValueDelete(std::string_view key) override;
+  absl::Status WaitAtBarrier(std::string barrier_id,
+                             absl::Duration timeout) override;
   absl::StatusOr<tsl::CoordinationServiceAgent*> GetCoordinationServiceAgent()
       override;

@@ -107,7 +107,7 @@ DistributedRuntimeCoordinationServiceClient::
 DistributedRuntimeCoordinationServiceClient::
     ~DistributedRuntimeCoordinationServiceClient() = default;

-xla::Status DistributedRuntimeCoordinationServiceClient::Connect() {
+absl::Status DistributedRuntimeCoordinationServiceClient::Connect() {
   const absl::Time deadline =
       absl::Now() +
       absl::Milliseconds(config_.cluster_register_timeout_in_ms());
Expand All @@ -130,7 +130,7 @@ xla::Status DistributedRuntimeCoordinationServiceClient::Connect() {
return s;
}

xla::Status DistributedRuntimeCoordinationServiceClient::Shutdown() {
absl::Status DistributedRuntimeCoordinationServiceClient::Shutdown() {
LOG(INFO) << "Distributed task shutdown initiated.";
Status s = coord_agent_->Shutdown();
LOG(INFO) << "Distributed task shutdown result: " << s;
@@ -162,17 +162,17 @@ DistributedRuntimeCoordinationServiceClient::KeyValueDirGet(
   return kvs;
 }

-xla::Status DistributedRuntimeCoordinationServiceClient::KeyValueDelete(
+absl::Status DistributedRuntimeCoordinationServiceClient::KeyValueDelete(
     std::string_view key) {
   return coord_agent_->DeleteKeyValue(key);
 }

-xla::Status DistributedRuntimeCoordinationServiceClient::KeyValueSet(
+absl::Status DistributedRuntimeCoordinationServiceClient::KeyValueSet(
     std::string_view key, std::string_view value) {
   return coord_agent_->InsertKeyValue(key, value);
 }

-xla::Status DistributedRuntimeCoordinationServiceClient::WaitAtBarrier(
+absl::Status DistributedRuntimeCoordinationServiceClient::WaitAtBarrier(
     std::string barrier_id, absl::Duration timeout) {
   return coord_agent_->WaitAtBarrier(barrier_id, timeout, /*tasks=*/{});
 }
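
Every signature in this file now returns absl::Status or absl::StatusOr<T>. For readers new to the Abseil types, a short refresher on the idiom (generic Abseil usage, unrelated to this commit's code; ParsePort and UsePort are invented for illustration):

#include <string_view>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/numbers.h"

// A fallible function returns absl::StatusOr<T>: either a value or an error.
absl::StatusOr<int> ParsePort(std::string_view text) {
  int port = 0;
  if (!absl::SimpleAtoi(text, &port)) {
    return absl::InvalidArgumentError("port is not a number");
  }
  return port;
}

// Callers check ok(), then use the value or propagate the error.
absl::Status UsePort(std::string_view text) {
  absl::StatusOr<int> port = ParsePort(text);
  if (!port.ok()) return port.status();
  // ... use *port ...
  return absl::OkStatus();
}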
18 changes: 9 additions & 9 deletions third_party/xla/xla/pjrt/distributed/client.h
@@ -71,9 +71,9 @@ class DistributedRuntimeClient
     // is reported by the coordinator, or we have not heard from the coordinator
     // recently. `coordinator_reported_failure` is true in the former case.
     // Exposed so tests can override this behavior to something non-fatal.
-    std::function<void(xla::Status, bool coordinator_reported_failure)>
+    std::function<void(absl::Status, bool coordinator_reported_failure)>
         missed_heartbeat_callback =
-            [](xla::Status status, bool coordinator_reported_failure) {
+            [](absl::Status status, bool coordinator_reported_failure) {
               if (coordinator_reported_failure) {
                 LOG(QFATAL)
                     << "Terminating process because the coordinator detected "
@@ -104,12 +104,12 @@ class DistributedRuntimeClient
   // connected.
   // Not thread-safe, i.e., calls to Connect()/Shutdown() must be serialized by
   // some other means.
-  virtual xla::Status Connect() = 0;
+  virtual absl::Status Connect() = 0;

   // Reports to the master that the client is ready to shutdown, and blocks
   // until all clients are ready to shutdown or the shutdown timeout expires.
   // Not thread-safe.
-  virtual xla::Status Shutdown() = 0;
+  virtual absl::Status Shutdown() = 0;

   // The following APIs are thread-safe.

Expand All @@ -127,17 +127,17 @@ class DistributedRuntimeClient {
virtual absl::StatusOr<std::vector<std::pair<std::string, std::string>>>
KeyValueDirGet(std::string_view key) = 0;

virtual xla::Status KeyValueSet(std::string_view key,
std::string_view value) = 0;
virtual absl::Status KeyValueSet(std::string_view key,
std::string_view value) = 0;

// Delete the key-value. If the key is a directory, recursively clean
// up all key-values under the directory.
virtual xla::Status KeyValueDelete(std::string_view key) = 0;
virtual absl::Status KeyValueDelete(std::string_view key) = 0;

// Blocks until all nodes are at the barrier or the barrier times out.
// `barrier_id` should be unique across barriers.
virtual xla::Status WaitAtBarrier(std::string barrier_id,
absl::Duration timeout) = 0;
virtual absl::Status WaitAtBarrier(std::string barrier_id,
absl::Duration timeout) = 0;

// Returns pointer to coordination service agent, or InternalError if the
// client does not use coordination service.
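
Taken together, the interface above describes a node's full lifecycle. A minimal sketch of that flow, assuming a DistributedRuntimeClient obtained as in the tests below (the key and barrier names are invented for illustration):

absl::Status RunNode(DistributedRuntimeClient& client) {
  // Register with the coordination service; blocks until all nodes have
  // connected or the init timeout expires.
  TF_RETURN_IF_ERROR(client.Connect());

  // Publish a value, then read one back with a bounded wait.
  TF_RETURN_IF_ERROR(client.KeyValueSet("example/key", "value"));
  TF_ASSIGN_OR_RETURN(
      std::string value,
      client.BlockingKeyValueGet("example/key", absl::Seconds(10)));
  (void)value;

  // Rendezvous with every other node; the id must be unique per barrier.
  TF_RETURN_IF_ERROR(
      client.WaitAtBarrier("example_barrier", absl::Seconds(10)));

  // Announce readiness to shut down; blocks until all nodes do the same.
  return client.Shutdown();
}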
66 changes: 33 additions & 33 deletions third_party/xla/xla/pjrt/distributed/client_server_test.cc
@@ -115,7 +115,7 @@ TEST_F(ClientServerTest, ConnectAndShutdownAreBarriers) {

   absl::Barrier barrier(num_nodes);

-  auto thread_fn = [&](int node_id) -> xla::Status {
+  auto thread_fn = [&](int node_id) -> absl::Status {
     auto client = GetClient(node_id);

     // Allow the threads to call Connect one-by-one in order.
@@ -155,7 +155,7 @@ TEST_F(ClientServerTest, ConnectAndShutdownAreBarriers) {
     return OkStatus();
   };

-  std::vector<xla::Status> statuses(num_nodes);
+  std::vector<absl::Status> statuses(num_nodes);
   {
     tsl::thread::ThreadPool thread_pool(tsl::Env::Default(), "test_threads",
                                         num_nodes);
@@ -207,7 +207,7 @@ TEST_F(ClientServerTest, ConnectAndEnumerateDevices) {
   // client. This ensures that devices are sent out of turn (compared to their
   // node ids).
   absl::Notification n;
-  auto thread0_fn = [&]() -> xla::Status {
+  auto thread0_fn = [&]() -> absl::Status {
     auto client = GetClient(/*node_id=*/0);
     GlobalTopologyProto topology;
     TF_RETURN_IF_ERROR(client->Connect());
Expand All @@ -234,7 +234,7 @@ TEST_F(ClientServerTest, ConnectAndEnumerateDevices) {
TF_RET_CHECK(value == "value2");
return OkStatus();
};
auto thread1_fn = [&]() -> xla::Status {
auto thread1_fn = [&]() -> absl::Status {
auto client = GetClient(/*node_id=*/1);
GlobalTopologyProto topology;
TF_RETURN_IF_ERROR(client->Connect());
Expand All @@ -261,9 +261,9 @@ TEST_F(ClientServerTest, ConnectAndEnumerateDevices) {
return OkStatus();
};

std::vector<std::function<xla::Status()>> functions = {thread0_fn,
thread1_fn};
std::vector<xla::Status> statuses(functions.size());
std::vector<std::function<absl::Status()>> functions = {thread0_fn,
thread1_fn};
std::vector<absl::Status> statuses(functions.size());
{
tsl::thread::ThreadPool thread_pool(tsl::Env::Default(), "test_threads",
functions.size());
@@ -299,7 +299,7 @@ TEST_F(ClientServerTest, EnumerateElevenDevices) {
     node->mutable_devices(0)->set_slice_index(i % 2);
   }

-  auto thread_fn = [&](int node_id) -> xla::Status {
+  auto thread_fn = [&](int node_id) -> absl::Status {
     auto client = GetClient(node_id);
     GlobalTopologyProto topology;
     TF_RETURN_IF_ERROR(client->Connect());
Expand All @@ -315,7 +315,7 @@ TEST_F(ClientServerTest, EnumerateElevenDevices) {
return OkStatus();
};

std::vector<xla::Status> statuses(num_nodes);
std::vector<absl::Status> statuses(num_nodes);
{
tsl::thread::ThreadPool thread_pool(tsl::Env::Default(), "test_threads",
num_nodes);
Expand All @@ -336,7 +336,7 @@ TEST_F(ClientServerTest, ZeroInitTimeoutShouldStillWaitForOtherTasks) {

absl::Barrier barrier(num_nodes);

auto thread_fn = [&](int node_id) -> xla::Status {
auto thread_fn = [&](int node_id) -> absl::Status {
DistributedRuntimeClient::Options client_options;
client_options.init_timeout = absl::ZeroDuration();
auto client = GetClient(node_id, client_options);
Expand All @@ -351,7 +351,7 @@ TEST_F(ClientServerTest, ZeroInitTimeoutShouldStillWaitForOtherTasks) {
return OkStatus();
};

std::vector<xla::Status> statuses(num_nodes);
std::vector<absl::Status> statuses(num_nodes);
{
tsl::thread::ThreadPool thread_pool(tsl::Env::Default(), "test_threads",
num_nodes);
Expand All @@ -368,11 +368,11 @@ TEST_F(ClientServerTest, ClientsTerminateShutdownIfAnyClientGoesAway) {
int num_nodes = 3;
StartService(num_nodes);

auto thread_fn = [&](int node_id) -> xla::Status {
auto thread_fn = [&](int node_id) -> absl::Status {
DistributedRuntimeClient::Options client_options;
client_options.shutdown_on_destruction = node_id != 0;
client_options.missed_heartbeat_callback =
[&](xla::Status status, bool coordinator_initiated) {};
[&](absl::Status status, bool coordinator_initiated) {};
auto client = GetClient(node_id, client_options);

TF_RETURN_IF_ERROR(client->Connect());
Expand All @@ -387,7 +387,7 @@ TEST_F(ClientServerTest, ClientsTerminateShutdownIfAnyClientGoesAway) {
return OkStatus();
};

std::vector<xla::Status> statuses(num_nodes);
std::vector<absl::Status> statuses(num_nodes);
{
tsl::thread::ThreadPool thread_pool(tsl::Env::Default(), "test_threads",
num_nodes);
Expand All @@ -413,11 +413,11 @@ TEST_F(ClientServerTest, ClientsReceiveMissedHeartbeatIfAnyClientGoesAway) {
int num_nodes = 3;
StartService(num_nodes);

auto thread_fn = [&](int node_id) -> xla::Status {
auto thread_fn = [&](int node_id) -> absl::Status {
DistributedRuntimeClient::Options client_options;
client_options.shutdown_on_destruction = (node_id != 0);
absl::Notification shutdown;
client_options.missed_heartbeat_callback = [&](xla::Status status,
client_options.missed_heartbeat_callback = [&](absl::Status status,
bool coordinator_initiated) {
shutdown.Notify();
};
Expand All @@ -432,7 +432,7 @@ TEST_F(ClientServerTest, ClientsReceiveMissedHeartbeatIfAnyClientGoesAway) {
return OkStatus();
};

std::vector<xla::Status> statuses(num_nodes);
std::vector<absl::Status> statuses(num_nodes);
{
tsl::thread::ThreadPool thread_pool(tsl::Env::Default(), "test_threads",
num_nodes);
Expand All @@ -456,12 +456,12 @@ TEST_F(ClientServerTest, ClientsTerminateIfServiceGoesAway) {

absl::Barrier barrier(num_nodes + 1);

auto thread_fn = [&](int node_id) -> xla::Status {
auto thread_fn = [&](int node_id) -> absl::Status {
DistributedRuntimeClient::Options client_options;
client_options.rpc_timeout = absl::Seconds(1);
client_options.shutdown_timeout = absl::Seconds(10);
absl::Notification shutdown;
client_options.missed_heartbeat_callback = [&](xla::Status status,
client_options.missed_heartbeat_callback = [&](absl::Status status,
bool coordinator_initiated) {
shutdown.Notify();
};
Expand All @@ -480,7 +480,7 @@ TEST_F(ClientServerTest, ClientsTerminateIfServiceGoesAway) {
return OkStatus();
};

std::vector<xla::Status> statuses(num_nodes);
std::vector<absl::Status> statuses(num_nodes);
{
tsl::thread::ThreadPool thread_pool(tsl::Env::Default(), "test_threads",
num_nodes);
Expand All @@ -502,7 +502,7 @@ TEST_F(ClientServerTest, LateClientsAreOk) {

absl::Barrier barrier(num_nodes);

auto thread_fn = [&](int node_id) -> xla::Status {
auto thread_fn = [&](int node_id) -> absl::Status {
DistributedRuntimeClient::Options client_options;
client_options.init_timeout = absl::Seconds(20);
client_options.rpc_timeout = absl::Milliseconds(200);
Expand All @@ -515,7 +515,7 @@ TEST_F(ClientServerTest, LateClientsAreOk) {
return OkStatus();
};

std::vector<xla::Status> statuses(num_nodes);
std::vector<absl::Status> statuses(num_nodes);
{
tsl::thread::ThreadPool thread_pool(tsl::Env::Default(), "test_threads",
num_nodes);
Expand All @@ -537,13 +537,13 @@ TEST_F(ClientServerTest, ConnectEventuallyTimesOutIfAClientDoesNotShowUp) {
service_options.shutdown_timeout = timeout;
StartService(num_nodes, service_options);

auto thread_fn = [&](int node_id) -> xla::Status {
auto thread_fn = [&](int node_id) -> absl::Status {
DistributedRuntimeClient::Options client_options;
client_options.init_timeout = timeout;
client_options.rpc_timeout = timeout;
// Overwrite the default error callback which invokes LOG(QFATAL).
client_options.missed_heartbeat_callback =
[](xla::Status status, bool coordinator_reported_failure) {
[](absl::Status status, bool coordinator_reported_failure) {
LOG(ERROR) << "Distributed client has missing heartbeats: " << status;
};
auto client = GetClient(node_id, client_options);
Expand All @@ -554,7 +554,7 @@ TEST_F(ClientServerTest, ConnectEventuallyTimesOutIfAClientDoesNotShowUp) {
};

// Note: one fewer thread than 'num_nodes'.
std::vector<xla::Status> statuses(num_nodes - 1);
std::vector<absl::Status> statuses(num_nodes - 1);
{
tsl::thread::ThreadPool thread_pool(tsl::Env::Default(), "test_threads",
num_nodes);
Expand All @@ -571,7 +571,7 @@ TEST_F(ClientServerTest, WaitAtBarrier_Succeed) {
int num_nodes = 2;
StartService(num_nodes);

auto thread_fn = [&](int node_id) -> xla::Status {
auto thread_fn = [&](int node_id) -> absl::Status {
auto client = GetClient(node_id);
TF_RETURN_IF_ERROR(client->Connect());

Expand All @@ -582,7 +582,7 @@ TEST_F(ClientServerTest, WaitAtBarrier_Succeed) {
return xla::OkStatus();
};

std::vector<xla::Status> statuses(num_nodes);
std::vector<absl::Status> statuses(num_nodes);
{
tsl::thread::ThreadPool thread_pool(tsl::Env::Default(), "test_threads",
num_nodes);
Expand All @@ -600,7 +600,7 @@ TEST_F(ClientServerTest, WaitAtBarrier_Timeout) {
StartService(num_nodes);
absl::Notification n;

auto thread_fn = [&](int node_id) -> xla::Status {
auto thread_fn = [&](int node_id) -> absl::Status {
auto client = GetClient(node_id);
TF_RETURN_IF_ERROR(client->Connect());

Expand All @@ -619,7 +619,7 @@ TEST_F(ClientServerTest, WaitAtBarrier_Timeout) {
return xla::OkStatus();
};

std::vector<xla::Status> statuses(num_nodes);
std::vector<absl::Status> statuses(num_nodes);
{
tsl::thread::ThreadPool thread_pool(tsl::Env::Default(), "test_threads",
num_nodes);
Expand All @@ -639,7 +639,7 @@ TEST_F(ClientServerTest, WaitAtBarrier_TimeoutWithDifferentBarrierId) {
int num_nodes = 2;
StartService(num_nodes);

auto thread_fn = [&](int node_id) -> xla::Status {
auto thread_fn = [&](int node_id) -> absl::Status {
auto client = GetClient(node_id);
TF_RETURN_IF_ERROR(client->Connect());

Expand All @@ -655,7 +655,7 @@ TEST_F(ClientServerTest, WaitAtBarrier_TimeoutWithDifferentBarrierId) {
return xla::OkStatus();
};

std::vector<xla::Status> statuses(num_nodes);
std::vector<absl::Status> statuses(num_nodes);
{
tsl::thread::ThreadPool thread_pool(tsl::Env::Default(), "test_threads",
num_nodes);
Expand All @@ -673,7 +673,7 @@ TEST_F(ClientServerTest, WaitAtBarrier_FailWithSameBarrierId) {
int num_nodes = 2;
StartService(num_nodes);

auto thread_fn = [&](int node_id) -> xla::Status {
auto thread_fn = [&](int node_id) -> absl::Status {
auto client = GetClient(node_id);
TF_RETURN_IF_ERROR(client->Connect());

Expand All @@ -684,7 +684,7 @@ TEST_F(ClientServerTest, WaitAtBarrier_FailWithSameBarrierId) {
return xla::OkStatus();
};

std::vector<xla::Status> statuses(num_nodes);
std::vector<absl::Status> statuses(num_nodes);
{
tsl::thread::ThreadPool thread_pool(tsl::Env::Default(), "test_threads",
num_nodes);
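
The hunks above all share one driver pattern: run thread_fn once per node on a thread pool, collect one absl::Status per node, and assert on the results after the pool joins. A condensed sketch of that pattern (the Schedule loop sits in the collapsed regions of this diff, so its exact body here is an assumption):

std::vector<absl::Status> statuses(num_nodes);
{
  tsl::thread::ThreadPool thread_pool(tsl::Env::Default(), "test_threads",
                                      num_nodes);
  for (int i = 0; i < num_nodes; ++i) {
    // Each task records its node's final status; the pool's destructor
    // joins all threads at the end of this scope.
    thread_pool.Schedule([&, i]() { statuses[i] = thread_fn(i); });
  }
}
for (const absl::Status& s : statuses) {
  TF_EXPECT_OK(s);  // per-test variants assert specific error codes instead
}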
