Skip to content

Commit

Permalink
Flush all tasks from local lineage cache after a node failure (ray-pr…
Browse files Browse the repository at this point in the history
  • Loading branch information
stephanie-wang authored and robertnishihara committed Jun 12, 2019
1 parent e0e52f1 commit 89ca5ee
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 1 deletion.
14 changes: 14 additions & 0 deletions src/ray/raylet/lineage_cache.cc
Expand Up @@ -221,6 +221,20 @@ bool LineageCache::CommitTask(const Task &task) {
}
}

void LineageCache::FlushAllUncommittedTasks() {
size_t num_flushed = 0;
for (const auto &entry : lineage_.GetEntries()) {
// Flush all tasks that have not yet committed.
if (entry.second.GetStatus() == GcsStatus::UNCOMMITTED) {
RAY_CHECK(UnsubscribeTask(entry.first));
FlushTask(entry.first);
num_flushed++;
}
}

RAY_LOG(DEBUG) << "Flushed " << num_flushed << " uncommitted tasks";
}

void LineageCache::MarkTaskAsForwarded(const TaskID &task_id, const ClientID &node_id) {
RAY_CHECK(!node_id.IsNil());
auto entry = lineage_.GetEntryMutable(task_id);
Expand Down
7 changes: 7 additions & 0 deletions src/ray/raylet/lineage_cache.h
Expand Up @@ -231,6 +231,13 @@ class LineageCache {
/// task was already in the COMMITTING state.
bool CommitTask(const Task &task);

/// Flush all tasks in the local cache that are not already being
/// committed. This is equivalent to all tasks in the UNCOMMITTED
/// state.
///
/// \return Void.
void FlushAllUncommittedTasks();

/// Add a task and its (estimated) uncommitted lineage to the local cache. We
/// will subscribe to commit notifications for all uncommitted tasks to
/// determine when it is safe to evict the lineage from the local cache.
Expand Down
55 changes: 54 additions & 1 deletion src/ray/raylet/lineage_cache_test.cc
Expand Up @@ -26,8 +26,22 @@ class MockGcs : public gcs::TableInterface<TaskID, protocol::Task>,
std::shared_ptr<protocol::TaskT> &task_data,
const gcs::TableInterface<TaskID, protocol::Task>::WriteCallback &done) {
task_table_[task_id] = task_data;
auto callback = done;
// If we requested notifications for this task ID, send the notification as
// part of the callback.
if (subscribed_tasks_.count(task_id) == 1) {
callback = [this, done](ray::gcs::AsyncGcsClient *client, const TaskID &task_id,
const protocol::TaskT &data) {
done(client, task_id, data);
// If we're subscribed to the task to be added, also send a
// subscription notification.
notification_callback_(client, task_id, data);
};
}

callbacks_.push_back(
std::pair<gcs::raylet::TaskTable::WriteCallback, TaskID>(done, task_id));
std::pair<gcs::raylet::TaskTable::WriteCallback, TaskID>(callback, task_id));
num_task_adds_++;
return ray::Status::OK();
}

Expand Down Expand Up @@ -78,28 +92,34 @@ class MockGcs : public gcs::TableInterface<TaskID, protocol::Task>,

const int NumRequestedNotifications() const { return num_requested_notifications_; }

const int NumTaskAdds() const { return num_task_adds_; }

private:
std::unordered_map<TaskID, std::shared_ptr<protocol::TaskT>> task_table_;
std::vector<std::pair<gcs::raylet::TaskTable::WriteCallback, TaskID>> callbacks_;
gcs::raylet::TaskTable::WriteCallback notification_callback_;
std::unordered_set<TaskID> subscribed_tasks_;
int num_requested_notifications_ = 0;
int num_task_adds_ = 0;
};

class LineageCacheTest : public ::testing::Test {
public:
LineageCacheTest()
: max_lineage_size_(10),
num_notifications_(0),
mock_gcs_(),
lineage_cache_(ClientID::FromRandom(), mock_gcs_, mock_gcs_, max_lineage_size_) {
mock_gcs_.Subscribe([this](ray::gcs::AsyncGcsClient *client, const TaskID &task_id,
const ray::protocol::TaskT &data) {
lineage_cache_.HandleEntryCommitted(task_id);
num_notifications_++;
});
}

protected:
uint64_t max_lineage_size_;
uint64_t num_notifications_;
MockGcs mock_gcs_;
LineageCache lineage_cache_;
};
Expand Down Expand Up @@ -529,6 +549,39 @@ TEST_F(LineageCacheTest, TestEvictionUncommittedChildren) {
ASSERT_EQ(lineage_cache_.GetLineage().GetChildrenSize(), 0);
}

TEST_F(LineageCacheTest, TestFlushAllUncommittedTasks) {
// Insert a chain of tasks.
std::vector<Task> tasks;
auto return_values =
InsertTaskChain(lineage_cache_, tasks, 3, std::vector<ObjectID>(), 1);
std::vector<TaskID> task_ids;
for (const auto &task : tasks) {
task_ids.push_back(task.GetTaskSpecification().TaskId());
}
// Check that we subscribed to each of the uncommitted tasks.
ASSERT_EQ(mock_gcs_.NumRequestedNotifications(), task_ids.size());

// Flush all uncommitted tasks and make sure we add all tasks to
// the task table.
lineage_cache_.FlushAllUncommittedTasks();
ASSERT_EQ(mock_gcs_.NumTaskAdds(), tasks.size());
// Flush again and make sure there are no new tasks added to the
// task table.
lineage_cache_.FlushAllUncommittedTasks();
ASSERT_EQ(mock_gcs_.NumTaskAdds(), tasks.size());

// Flush all GCS notifications.
mock_gcs_.Flush();
// Make sure that we unsubscribed to the uncommitted tasks before
// we flushed them.
ASSERT_EQ(num_notifications_, 0);

// Flush again and make sure there are no new tasks added to the
// task table.
lineage_cache_.FlushAllUncommittedTasks();
ASSERT_EQ(mock_gcs_.NumTaskAdds(), tasks.size());
}

} // namespace raylet

} // namespace ray
Expand Down
5 changes: 5 additions & 0 deletions src/ray/raylet/node_manager.cc
Expand Up @@ -475,6 +475,11 @@ void NodeManager::ClientRemoved(const ClientTableDataT &client_data) {
// Notify the object directory that the client has been removed so that it
// can remove it from any cached locations.
object_directory_->HandleClientRemoved(client_id);

// Flush all uncommitted tasks from the local lineage cache. This is to
// guarantee that all tasks get flushed eventually, in case one of the tasks
// in our local cache was supposed to be flushed by the node that died.
lineage_cache_.FlushAllUncommittedTasks();
}

void NodeManager::ResourceCreateUpdated(const ClientTableDataT &client_data) {
Expand Down

0 comments on commit 89ca5ee

Please sign in to comment.