Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[data] Fix early stop for multiple limit ops. #42958

Merged
merged 6 commits into from
Feb 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -178,10 +178,10 @@ def __init__(
assert isinstance(x, PhysicalOperator), x
self._inputs_complete = not input_dependencies
self._target_max_block_size = target_max_block_size
self._dependents_complete = False
self._started = False
self._metrics = OpRuntimeMetrics(self)
self._estimated_output_blocks = None
self._execution_completed = False

def __reduce__(self):
raise ValueError("Operator is not serializable.")
Expand All @@ -207,18 +207,22 @@ def actual_target_max_block_size(self) -> int:
def set_target_max_block_size(self, target_max_block_size: Optional[int]):
self._target_max_block_size = target_max_block_size

def mark_execution_completed(self):
"""Manually mark this operator has completed execution."""
self._execution_completed = True

def completed(self) -> bool:
"""Return True when this operator is completed.

An operator is completed if any of the following conditions are met:
- All upstream operators are completed and all outputs are taken.
- All downstream operators are completed.
An operator is completed when it has stopped execution and all of its
outputs are taken.
"""
return (
self._inputs_complete
and self.num_active_tasks() == 0
and not self.has_next()
) or self._dependents_complete
if not self._execution_completed:
if self._inputs_complete and self.num_active_tasks() == 0:
# If all inputs are complete and there are no active tasks,
# then the operator has completed execution.
self._execution_completed = True
return self._execution_completed and not self.has_next()

def get_stats(self) -> StatsDict:
"""Return recorded execution stats for use with DatasetStats."""
Expand Down Expand Up @@ -270,13 +274,6 @@ def should_add_input(self) -> bool:
"""
return True

def need_more_inputs(self) -> bool:
"""Return true if the operator still needs more inputs.

Once this returns false, it should never return true again.
"""
return True

def add_input(self, refs: RefBundle, input_index: int) -> None:
"""Called when an upstream result is available.

Expand Down Expand Up @@ -314,13 +311,6 @@ def all_inputs_done(self) -> None:
"""
self._inputs_complete = True

def all_dependents_complete(self) -> None:
"""Called when all downstream operators have completed().

After this is called, the operator is marked as completed.
"""
self._dependents_complete = True

def has_next(self) -> bool:
"""Returns when a downstream output is available.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,11 @@ def __init__(
self._cur_output_bundles = 0
super().__init__(self._name, input_op, target_max_block_size=None)
if self._limit <= 0:
self.all_inputs_done()
self.mark_execution_completed()

def _limit_reached(self) -> bool:
return self._consumed_rows >= self._limit

def need_more_inputs(self) -> bool:
return not self._limit_reached()

def _add_input_inner(self, refs: RefBundle, input_index: int) -> None:
assert not self.completed()
assert input_index == 0, input_index
Expand Down Expand Up @@ -79,7 +76,7 @@ def slice_fn(block, metadata, num_rows) -> Tuple[Block, BlockMetadata]:
)
self._buffer.append(out_refs)
if self._limit_reached():
self.all_inputs_done()
self.mark_execution_completed()

# We cannot estimate if we have only consumed empty blocks
if self._consumed_rows > 0:
Expand Down Expand Up @@ -107,12 +104,14 @@ def get_stats(self) -> StatsDict:
return {self._name: self._output_metadata}

def num_outputs_total(self) -> int:
# Before inputs are completed (either because the limit is reached or
# because the input operators are done), we don't know how many output
# Before execution is completed, we don't know how many output
# bundles we will have. We estimate based off the consumption so far.
if self._inputs_complete:
if self._execution_completed:
return self._cur_output_bundles
elif self._estimated_output_blocks is not None:
return self._estimated_output_blocks
else:
return self.input_dependencies[0].num_outputs_total()

def throttling_disabled(self) -> bool:
return True
19 changes: 8 additions & 11 deletions python/ray/data/_internal/execution/streaming_executor_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,6 @@ def __init__(self, op: PhysicalOperator, inqueues: List[OpBufferQueue]):
self.inputs_done_called = False
# Tracks whether `input_done` is called for each input op.
self.input_done_called = [False] * len(op.input_dependencies)
self.dependents_completed_called = False
# Used for StreamingExecutor to signal exception or end of execution
self._finished: bool = False
self._exception: Optional[Exception] = None
Expand Down Expand Up @@ -479,17 +478,16 @@ def update_operator_states(topology: Topology) -> None:
op_state.inputs_done_called = True

# Traverse the topology in reverse topological order.
# For each op, if all of its downstream operators don't need any more inputs,
# call all_dependents_complete() to also complete this op.
# For each op, if all of its downstream operators have completed,
# call mark_execution_completed() to also complete this op.
for op, op_state in reversed(list(topology.items())):
if op_state.dependents_completed_called:
if op.completed():
continue
dependents_completed = len(op.output_dependencies) > 0 and all(
not dep.need_more_inputs() for dep in op.output_dependencies
dep.completed() for dep in op.output_dependencies
)
if dependents_completed:
op.all_dependents_complete()
op_state.dependents_completed_called = True
op.mark_execution_completed()


def select_operator_to_run(
Expand Down Expand Up @@ -518,11 +516,10 @@ def select_operator_to_run(
for op, state in topology.items():
under_resource_limits = _execution_allowed(op, resource_manager)
if (
op.need_more_inputs()
under_resource_limits
and not op.completed()
and state.num_queued() > 0
and op.should_add_input()
and under_resource_limits
and not op.completed()
and all(p.can_add_input(op) for p in backpressure_policies)
):
ops.append(op)
Expand Down Expand Up @@ -551,7 +548,7 @@ def select_operator_to_run(
ops = [
op
for op, state in topology.items()
if op.need_more_inputs() and state.num_queued() > 0 and not op.completed()
if state.num_queued() > 0 and not op.completed()
]

# Nothing to run.
Expand Down
11 changes: 8 additions & 3 deletions python/ray/data/tests/test_consumption.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,15 +562,20 @@ def range_(i):

source = CountingRangeDatasource()

total_rows = 100
parallelism = 10
ds = ray.data.read_datasource(
source,
parallelism=parallelism,
n=10,
n=total_rows // parallelism,
)
ds2 = ds.limit(limit)
# Apply multiple limit ops.
# Once the smallest limit is reached, the entire dataset should stop execution.
ds = ds.limit(total_rows)
ds = ds.limit(limit)
ds = ds.limit(total_rows)
# Check content.
assert extract_values("id", ds2.take(limit)) == list(range(limit))
assert extract_values("id", ds.take(limit)) == list(range(limit))
# Check number of read tasks launched.
# min_read_tasks is the minimum number of read tasks needed for the limit.
# We may launch more tasks than this number, in order to maximize throughput.
Expand Down
16 changes: 9 additions & 7 deletions python/ray/data/tests/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,7 +614,9 @@ def test_limit_operator(ray_start_regular_shared):
refs = make_ref_bundles([[i] * num_rows_per_block for i in range(num_refs)])
input_op = InputDataBuffer(refs)
limit_op = LimitOperator(limit, input_op)
limit_op.all_inputs_done = MagicMock(wraps=limit_op.all_inputs_done)
limit_op.mark_execution_completed = MagicMock(
wraps=limit_op.mark_execution_completed
)
if limit == 0:
# If the limit is 0, the operator should be completed immediately.
assert limit_op.completed()
Expand All @@ -624,24 +626,24 @@ def test_limit_operator(ray_start_regular_shared):
while input_op.has_next() and not limit_op._limit_reached():
loop_count += 1
assert not limit_op.completed(), limit
assert limit_op.need_more_inputs(), limit
assert not limit_op._execution_completed, limit
limit_op.add_input(input_op.get_next(), 0)
while limit_op.has_next():
# Drain the outputs. So the limit operator
# will be completed when the limit is reached.
limit_op.get_next()
cur_rows += num_rows_per_block
if cur_rows >= limit:
assert limit_op.all_inputs_done.call_count == 1, limit
assert limit_op.mark_execution_completed.call_count == 1, limit
assert limit_op.completed(), limit
assert limit_op._limit_reached(), limit
assert not limit_op.need_more_inputs(), limit
assert limit_op._execution_completed, limit
else:
assert limit_op.all_inputs_done.call_count == 0, limit
assert limit_op.mark_execution_completed.call_count == 0, limit
assert not limit_op.completed(), limit
assert not limit_op._limit_reached(), limit
assert limit_op.need_more_inputs(), limit
limit_op.all_inputs_done()
assert not limit_op._execution_completed, limit
limit_op.mark_execution_completed()
# After inputs done, the number of output bundles
# should be the same as the number of `add_input`s.
assert limit_op.num_outputs_total() == loop_count, limit
Expand Down
23 changes: 16 additions & 7 deletions python/ray/data/tests/test_streaming_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,33 +133,42 @@ def test_process_completed_tasks():
done_task = MetadataOpTask(0, ray.put("done"), done_task_callback)
o2.get_active_tasks = MagicMock(return_value=[sleep_task, done_task])
o2.all_inputs_done = MagicMock()
o1.all_dependents_complete = MagicMock()
o1.mark_execution_completed = MagicMock()
process_completed_tasks(topo, [], 0)
update_operator_states(topo)
done_task_callback.assert_called_once()
o2.all_inputs_done.assert_not_called()
o1.all_dependents_complete.assert_not_called()
o1.mark_execution_completed.assert_not_called()

# Test input finalization.
done_task_callback = MagicMock()
done_task = MetadataOpTask(0, ray.put("done"), done_task_callback)
o2.get_active_tasks = MagicMock(return_value=[done_task])
o2.all_inputs_done = MagicMock()
o1.all_dependents_complete = MagicMock()
o1.mark_execution_completed = MagicMock()
o1.completed = MagicMock(return_value=True)
topo[o1].outqueue.clear()
process_completed_tasks(topo, [], 0)
update_operator_states(topo)
done_task_callback.assert_called_once()
o2.all_inputs_done.assert_called_once()
o1.all_dependents_complete.assert_not_called()
o1.mark_execution_completed.assert_not_called()

# Test dependents completed.
o2.need_more_inputs = MagicMock(return_value=False)
o1.all_dependents_complete = MagicMock()
o1 = InputDataBuffer(inputs)
o2 = MapOperator.create(
make_map_transformer(lambda block: [b * -1 for b in block]), o1
)
o3 = MapOperator.create(
make_map_transformer(lambda block: [b * -1 for b in block]), o2
)
topo, _ = build_streaming_topology(o3, ExecutionOptions(verbose_progress=True))

o3.mark_execution_completed()
o2.mark_execution_completed = MagicMock()
process_completed_tasks(topo, [], 0)
update_operator_states(topo)
o1.all_dependents_complete.assert_called_once()
o2.mark_execution_completed.assert_called_once()


def test_select_operator_to_run():
Expand Down