diff --git a/src/ray/raylet/task_dependency_manager.cc b/src/ray/raylet/task_dependency_manager.cc index e0d447824276d..fe4364c4491f4 100644 --- a/src/ray/raylet/task_dependency_manager.cc +++ b/src/ray/raylet/task_dependency_manager.cc @@ -304,28 +304,35 @@ void TaskDependencyManager::TaskCanceled(const TaskID &task_id) { void TaskDependencyManager::RemoveTasksAndRelatedObjects( const std::unordered_set &task_ids) { - if (task_ids.empty()) { - return; - } - + // Collect a list of all the unique objects that these tasks were subscribed + // to. + std::unordered_set required_objects; for (auto it = task_ids.begin(); it != task_ids.end(); it++) { + auto task_it = task_dependencies_.find(*it); + if (task_it != task_dependencies_.end()) { + // Add the objects that this task was subscribed to. + required_objects.insert(task_it->second.object_dependencies.begin(), + task_it->second.object_dependencies.end()); + } + // The task no longer depends on anything. task_dependencies_.erase(*it); - required_tasks_.erase(*it); + // The task is no longer pending execution. pending_tasks_.erase(*it); } - // TODO: the size of required_objects_ could be large, consider to add - // an index if this turns out to be a perf problem. - for (auto it = required_objects_.begin(); it != required_objects_.end();) { - const auto object_id = *it; + // Cancel all of the objects that were required by the removed tasks. + for (const auto &object_id : required_objects) { TaskID creating_task_id = ComputeTaskId(object_id); - if (task_ids.find(creating_task_id) != task_ids.end()) { - object_manager_.CancelPull(object_id); - reconstruction_policy_.Cancel(object_id); - it = required_objects_.erase(it); - } else { - it++; - } + required_tasks_.erase(creating_task_id); + HandleRemoteDependencyCanceled(object_id); + } + + // Make sure that the tasks in task_ids no longer have tasks dependent on + // them. + for (const auto &task_id : task_ids) { + RAY_CHECK(required_tasks_.find(task_id) == required_tasks_.end()) + << "RemoveTasksAndRelatedObjects was called on" << task_id + << ", but another task depends on it that was not included in the argument"; } } diff --git a/src/ray/raylet/task_dependency_manager.h b/src/ray/raylet/task_dependency_manager.h index bb49f4bc182a3..afb146af6a8f3 100644 --- a/src/ray/raylet/task_dependency_manager.h +++ b/src/ray/raylet/task_dependency_manager.h @@ -106,10 +106,12 @@ class TaskDependencyManager { /// \return Return a vector of TaskIDs for tasks registered as pending. std::vector GetPendingTasks() const; - /// Remove all of the tasks specified, and all the objects created by - /// these tasks from task dependency manager. + /// Remove all of the tasks specified. These tasks will no longer be + /// considered pending and the objects they depend on will no longer be + /// required. /// - /// \param task_ids The collection of task IDs. + /// \param task_ids The collection of task IDs. For a given task in this set, + /// all tasks that depend on the task must also be included in the set. void RemoveTasksAndRelatedObjects(const std::unordered_set &task_ids); /// Returns debug string for class. diff --git a/src/ray/raylet/task_dependency_manager_test.cc b/src/ray/raylet/task_dependency_manager_test.cc index 1e05283172320..f414d74695652 100644 --- a/src/ray/raylet/task_dependency_manager_test.cc +++ b/src/ray/raylet/task_dependency_manager_test.cc @@ -415,6 +415,62 @@ TEST_F(TaskDependencyManagerTest, TestTaskLeaseRenewal) { Run(sleep_time); } +TEST_F(TaskDependencyManagerTest, TestRemoveTasksAndRelatedObjects) { + // Create 3 tasks, each dependent on the previous. The first task has no + // arguments. + int num_tasks = 3; + auto tasks = MakeTaskChain(num_tasks, {}, 1); + // No objects should be remote or canceled since each task depends on a + // locally queued task. + EXPECT_CALL(object_manager_mock_, Pull(_)).Times(0); + EXPECT_CALL(reconstruction_policy_mock_, ListenAndMaybeReconstruct(_)).Times(0); + EXPECT_CALL(object_manager_mock_, CancelPull(_)).Times(0); + EXPECT_CALL(reconstruction_policy_mock_, Cancel(_)).Times(0); + for (const auto &task : tasks) { + // Subscribe to each of the tasks' arguments. + const auto &arguments = task.GetDependencies(); + task_dependency_manager_.SubscribeDependencies(task.GetTaskSpecification().TaskId(), + arguments); + // Mark each task as pending. A lease entry should be added to the GCS for + // each task. + EXPECT_CALL(gcs_mock_, Add(_, task.GetTaskSpecification().TaskId(), _, _)); + task_dependency_manager_.TaskPending(task); + } + + // Simulate executing the first task. This should make the second task + // runnable. + auto task = tasks.front(); + TaskID task_id = task.GetTaskSpecification().TaskId(); + auto return_id = task.GetTaskSpecification().ReturnId(0); + task_dependency_manager_.UnsubscribeDependencies(task_id); + // Simulate the object notifications for the task's return values. + auto ready_tasks = task_dependency_manager_.HandleObjectLocal(return_id); + // The second task should be ready to run. + ASSERT_EQ(ready_tasks.size(), 1); + // Simulate the task finishing execution. + task_dependency_manager_.TaskCanceled(task_id); + + // Remove all tasks from the manager except the first task, which already + // finished executing. + std::unordered_set task_ids; + for (const auto &task : tasks) { + task_ids.insert(task.GetTaskSpecification().TaskId()); + } + task_ids.erase(task_id); + task_dependency_manager_.RemoveTasksAndRelatedObjects(task_ids); + // Simulate evicting the return value of the first task. Make sure that this + // does not return the second task, which should have been removed. + auto waiting_tasks = task_dependency_manager_.HandleObjectMissing(return_id); + ASSERT_TRUE(waiting_tasks.empty()); + + // Simulate the object notifications for the second task's return values. + // Make sure that this does not return the third task, which should have been + // removed. + return_id = tasks[1].GetTaskSpecification().ReturnId(0); + ready_tasks = task_dependency_manager_.HandleObjectLocal(return_id); + ASSERT_TRUE(ready_tasks.empty()); +} + } // namespace raylet } // namespace ray