Skip to content

Commit

Permalink
[core] Fix bug in task dependency management for duplicate args (#16365)
Browse files Browse the repository at this point in the history
* Pytest

* Skip on windows

* C++
  • Loading branch information
stephanie-wang committed Jun 22, 2021
1 parent 5efeb53 commit e7b752c
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 7 deletions.
32 changes: 32 additions & 0 deletions python/ray/tests/test_scheduling.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,38 @@ def f(x):
assert object_memory_usage() == 0


@pytest.mark.skipif(sys.platform == "win32", reason="Fails on windows")
def test_many_args(ray_start_cluster):
# This test ensures that a task will run where its task dependencies are
# located, even when those objects are borrowed.
cluster = ray_start_cluster
object_size = int(1e6)

# Disable worker caching so worker leases are not reused, and disable
# inlining of return objects so return objects are always put into Plasma.
for _ in range(4):
cluster.add_node(
num_cpus=1, object_store_memory=(4 * object_size * 25))
ray.init(address=cluster.address)

@ray.remote
def f(i, *args):
print(i)
return

@ray.remote
def put():
return np.zeros(object_size, dtype=np.uint8)

xs = [put.remote() for _ in range(100)]
ray.wait(xs, num_returns=len(xs), fetch_local=False)
tasks = []
for i in range(100):
args = [np.random.choice(xs) for _ in range(25)]
tasks.append(f.remote(i, *args))
ray.get(tasks, timeout=30)


if __name__ == "__main__":
import pytest
sys.exit(pytest.main(["-v", __file__]))
2 changes: 1 addition & 1 deletion src/ray/object_manager/pull_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -347,9 +347,9 @@ std::vector<ObjectID> PullManager::CancelPull(uint64_t request_id) {
std::vector<ObjectID> object_ids_to_cancel_subscription;
for (const auto &ref : bundle_it->second.objects) {
auto obj_id = ObjectRefToId(ref);
RAY_LOG(DEBUG) << "Removing an object pull request of id: " << obj_id;
auto it = object_pull_requests_.find(obj_id);
if (it != object_pull_requests_.end()) {
RAY_LOG(DEBUG) << "Removing an object pull request of id: " << obj_id;
it->second.bundle_request_ids.erase(bundle_it->first);
if (it->second.bundle_request_ids.empty()) {
object_pull_requests_.erase(it);
Expand Down
7 changes: 6 additions & 1 deletion src/ray/raylet/dependency_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,10 @@ bool DependencyManager::RequestTaskDependencies(
const TaskID &task_id, const std::vector<rpc::ObjectReference> &required_objects) {
RAY_LOG(DEBUG) << "Adding dependencies for task " << task_id
<< ". Required objects length: " << required_objects.size();
auto inserted = queued_task_requests_.emplace(task_id, required_objects);

const auto required_ids = ObjectRefsToIds(required_objects);
absl::flat_hash_set<ObjectID> deduped_ids(required_ids.begin(), required_ids.end());
auto inserted = queued_task_requests_.emplace(task_id, std::move(deduped_ids));
RAY_CHECK(inserted.second) << "Task depedencies can be requested only once per task. "
<< task_id;
auto &task_entry = inserted.first->second;
Expand All @@ -167,7 +170,9 @@ bool DependencyManager::RequestTaskDependencies(

auto it = GetOrInsertRequiredObject(obj_id, ref);
it->second.dependent_tasks.insert(task_id);
}

for (const auto &obj_id : task_entry.dependencies) {
if (local_objects_.count(obj_id)) {
task_entry.num_missing_dependencies--;
}
Expand Down
7 changes: 2 additions & 5 deletions src/ray/raylet/dependency_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -192,11 +192,8 @@ class DependencyManager : public TaskDependencyManagerInterface {

/// A struct to represent the object dependencies of a task.
struct TaskDependencies {
TaskDependencies(const std::vector<rpc::ObjectReference> &deps)
: num_missing_dependencies(deps.size()) {
const auto dep_ids = ObjectRefsToIds(deps);
dependencies.insert(dep_ids.begin(), dep_ids.end());
}
TaskDependencies(const absl::flat_hash_set<ObjectID> &deps)
: dependencies(std::move(deps)), num_missing_dependencies(dependencies.size()) {}
/// The objects that the task depends on. These are the arguments to the
/// task. These must all be simultaneously local before the task is ready
/// to execute. Objects are removed from this set once
Expand Down
31 changes: 31 additions & 0 deletions src/ray/raylet/dependency_manager_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,37 @@ TEST_F(DependencyManagerTest, TestWaitObjectLocal) {
AssertNoLeaks();
}

/// Test requesting the dependencies for a task. The dependency manager should
/// return the task ID as ready once all of its unique arguments are local.
TEST_F(DependencyManagerTest, TestDuplicateTaskArgs) {
// Create a task with 3 arguments.
int num_arguments = 3;
auto obj_id = ObjectID::FromRandom();
std::vector<ObjectID> arguments;
for (int i = 0; i < num_arguments; i++) {
arguments.push_back(obj_id);
}
TaskID task_id = RandomTaskId();
bool ready =
dependency_manager_.RequestTaskDependencies(task_id, ObjectIdsToRefs(arguments));
ASSERT_FALSE(ready);
ASSERT_EQ(object_manager_mock_.active_task_requests.size(), 1);

auto ready_task_ids = dependency_manager_.HandleObjectLocal(obj_id);
ASSERT_EQ(ready_task_ids.size(), 1);
ASSERT_EQ(ready_task_ids.front(), task_id);
dependency_manager_.RemoveTaskDependencies(task_id);

TaskID task_id2 = RandomTaskId();
ready =
dependency_manager_.RequestTaskDependencies(task_id2, ObjectIdsToRefs(arguments));
ASSERT_TRUE(ready);
ASSERT_EQ(object_manager_mock_.active_task_requests.size(), 1);
dependency_manager_.RemoveTaskDependencies(task_id2);

AssertNoLeaks();
}

} // namespace raylet

} // namespace ray
Expand Down

0 comments on commit e7b752c

Please sign in to comment.