Skip to content

Commit

Permalink
Fixes some corner cases when a non-participating task makes the barrier call
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 629250068
  • Loading branch information
anshumang authored and tensorflower-gardener committed Apr 30, 2024
1 parent dc0ae9a commit bfae8b9
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1028,8 +1028,37 @@ void CoordinationServiceStandaloneImpl::BarrierAsync(
const CoordinatedTask& task,
const std::vector<CoordinatedTask>& participating_tasks,
StatusCallback done) {
VLOG(3) << "Task " << GetTaskName(task) << "invoked BarrierAsync("
VLOG(3) << "Task " << GetTaskName(task) << " invoked BarrierAsync("
<< barrier_id << ").";

// Check if caller task is participating in the barrier. If not, update
// `barriers_` to cause subsequent calls from the same task and other tasks
// that have already called this instance of the barrier to fail.
const std::string source_task_name = GetTaskName(task);

bool among_participating_tasks =
std::find_if(participating_tasks.begin(), participating_tasks.end(),
[&](const CoordinatedTask& task) {
return GetTaskName(task) == source_task_name;
}) != participating_tasks.end();

if (!participating_tasks.empty() && !among_participating_tasks) {
const std::string task_name = GetTaskName(task);
absl::Status error = MakeCoordinationError(errors::InvalidArgument(
absl::StrCat("A non-participating task (", GetTaskName(task),
") called the barrier: ", barrier_id)));
{
mutex_lock l(state_mu_);
auto pair = barriers_.try_emplace(barrier_id);
auto it = pair.first;
auto* barrier = &it->second;
// Make sure subsequent calls fail and existing waiting tasks receive the
// error.
PassBarrier(barrier_id, error, barrier);
}
done(error);
return;
}
mutex_lock l(state_mu_);
auto pair = barriers_.try_emplace(barrier_id);
auto it = pair.first;
Expand Down Expand Up @@ -1117,16 +1146,6 @@ void CoordinationServiceStandaloneImpl::BarrierAsync(
// Add pending callbacks.
barrier->done_callbacks.push_back(done);

// Check if caller task is participating in the barrier.
if (!barrier->tasks_at_barrier.contains(task)) {
// Unexpected barrier call from a task not participating in the barrier.
absl::Status error = MakeCoordinationError(errors::InvalidArgument(
absl::StrCat("A non-participating task (", GetTaskName(task),
") called the barrier: ", barrier_id)));
PassBarrier(barrier_id, error, barrier);
return;
}

// Check if task args are specified consistently across barrier calls.
if (!ValidateTaskArgs(participating_tasks, barrier->tasks_at_barrier,
cluster_state_.size())) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,7 @@ class TestCoordinationClient : public CoordinationClient {
#define UNIMPLEMENTED(method) \
void method##Async(const method##Request* request, \
method##Response* response, StatusCallback done) \
override { \
done(errors::Unimplemented(#method "Async")); \
override{done(errors::Unimplemented(#method "Async")); \
}

UNIMPLEMENTED(WaitForAllTasks);
Expand All @@ -123,8 +122,7 @@ class TestCoordinationClient : public CoordinationClient {
#define UNIMPLEMENTED_WITH_CALL_OPTS(method) \
void method##Async(CallOptions* call_opts, const method##Request* request, \
method##Response* response, StatusCallback done) \
override { \
done(errors::Unimplemented(#method "Async")); \
override{done(errors::Unimplemented(#method "Async")); \
}

UNIMPLEMENTED_WITH_CALL_OPTS(GetKeyValue);
Expand Down Expand Up @@ -203,6 +201,14 @@ class CoordinationBarrierTest : public ::testing::Test {
"/task:", task.task_id());
}

std::vector<TestCoordinationClient*> GetClients() {
std::vector<TestCoordinationClient*> clients;
for (const auto& client : clients_) {
clients.push_back(client.get());
}
return clients;
}

private:
std::unique_ptr<CoordinationServiceInterface> coord_service_;
std::vector<CoordinatedTask> tasks_;
Expand Down Expand Up @@ -964,6 +970,49 @@ TEST_F(CoordinationBarrierTest, BarrierByNonParticipatingTask) {
EXPECT_TRUE(absl::IsInvalidArgument(barrier_status_1));
}

// Scenario: tasks 0 and 1 complete a two-task barrier successfully, then
// task 2 (which is not in the participant list) calls the same barrier and
// must be rejected with an InvalidArgument error.
TEST_F(CoordinationBarrierTest, BarrierByNonParticipatingTaskThreeTasks) {
  const std::string barrier_id = "barrier_id";
  const absl::Duration timeout = absl::Seconds(5);
  // Every call below passes the same participant list: tasks 0 and 1 only.
  const std::vector<CoordinatedTask> participants = {GetTask(0), GetTask(1)};

  absl::Status status_task_0;
  absl::Status status_task_1;
  absl::Status status_task_2;
  absl::Notification task_0_done;
  absl::Notification task_1_done;

  // Both participating tasks arrive at the barrier.
  GetCoordinationService()->BarrierAsync(
      barrier_id, timeout, GetTask(0),
      /*participating_tasks=*/participants,
      [&status_task_0, &task_0_done](const absl::Status& result) {
        status_task_0 = result;
        task_0_done.Notify();
      });
  GetCoordinationService()->BarrierAsync(
      barrier_id, timeout, GetTask(1),
      /*participating_tasks=*/participants,
      [&status_task_1, &task_1_done](const absl::Status& result) {
        status_task_1 = result;
        task_1_done.Notify();
      });

  task_0_done.WaitForNotification();
  task_1_done.WaitForNotification();

  // Only participating tasks have called the barrier so far, so it passes.
  TF_EXPECT_OK(status_task_0);
  TF_EXPECT_OK(status_task_1);

  // Task 2 now calls the barrier even though it is not a participant.
  GetCoordinationService()->BarrierAsync(
      barrier_id, timeout, GetTask(2),
      /*participating_tasks=*/participants,
      [&status_task_2](const absl::Status& result) {
        status_task_2 = result;
      });

  // The non-participating caller must see the barrier fail.
  EXPECT_TRUE(absl::IsInvalidArgument(status_task_2));

  // Other clients would need to check the barrier key to detect the error.
}

TEST_F(CoordinationBarrierTest, BarrierByNonClusterTask) {
const std::string barrier_id = "barrier_id";
absl::Duration timeout = absl::Seconds(5);
Expand Down

0 comments on commit bfae8b9

Please sign in to comment.