[data] Inherit block size from downstream ops (#41019)
Stacked on #40757

Compute the target block size for each operation before applying the other optimizer rules that depend on it (SplitReadOutputBlocksRule, which this change renames to SetReadParallelismRule). This also simplifies block sizing: an op's target block size is always propagated to all upstream ops, until we reach an op that has a different block size set. (A standalone sketch of this propagation follows the changed-file summary below.)
Related issue number

Closes #41018.

---------

Signed-off-by: Stephanie Wang <swang@cs.berkeley.edu>
stephanie-wang committed Nov 29, 2023
1 parent 64e5373 commit 6f7378c
Showing 11 changed files with 182 additions and 87 deletions.
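Before the per-file diffs, here is a minimal standalone sketch of the propagation described in the commit message above. ToyOp and propagate are hypothetical stand-ins rather than Ray APIs; only the recursion mirrors the new InheritTargetMaxBlockSizeRule introduced below.

from typing import List, Optional


class ToyOp:
    """Hypothetical stand-in for a physical operator in the DAG."""

    def __init__(
        self,
        name: str,
        inputs: Optional[List["ToyOp"]] = None,
        target_max_block_size: Optional[int] = None,
    ):
        self.name = name
        self.input_dependencies = inputs or []
        self.target_max_block_size = target_max_block_size


def propagate(op: ToyOp, inherited: Optional[int] = None) -> None:
    # An op with its own override becomes the value inherited by its
    # upstream ops; ops without an override inherit from downstream.
    if op.target_max_block_size is not None:
        inherited = op.target_max_block_size
    elif inherited is not None:
        op.target_max_block_size = inherited
    for upstream in op.input_dependencies:
        propagate(upstream, inherited)


# read -> map1 -> shuffle (128 MiB override) -> map2
read = ToyOp("read")
map1 = ToyOp("map1", [read])
shuffle = ToyOp("shuffle", [map1], target_max_block_size=128 * 1024 * 1024)
map2 = ToyOp("map2", [shuffle])

propagate(map2)  # start from the most downstream op
assert read.target_max_block_size == 128 * 1024 * 1024  # inherited from shuffle
assert map2.target_max_block_size is None  # falls back to the default

Ops downstream of an override (map2 here) keep None and fall back to DataContext.target_max_block_size, which is exactly what actual_target_max_block_size does in the physical operator interface changed below.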
2 changes: 1 addition & 1 deletion python/ray/data/BUILD
@@ -131,7 +131,7 @@ py_test(

py_test(
name = "test_block_sizing",
size = "small",
size = "medium",
srcs = ["tests/test_block_sizing.py"],
tags = ["team:data", "exclusive"],
deps = ["//:ray_lib", ":conftest"],
@@ -204,6 +204,9 @@ def actual_target_max_block_size(self) -> int:
target_max_block_size = DataContext.get_current().target_max_block_size
return target_max_block_size

def set_target_max_block_size(self, target_max_block_size: Optional[int]):
self._target_max_block_size = target_max_block_size

def completed(self) -> bool:
"""Return True when this operator is completed.
14 changes: 14 additions & 0 deletions python/ray/data/_internal/logical/operators/read_operator.py
@@ -20,3 +20,17 @@ def __init__(
self._datasource_or_legacy_reader = datasource_or_legacy_reader
self._parallelism = parallelism
self._mem_size = mem_size
self._detected_parallelism = None

def set_detected_parallelism(self, parallelism: int):
"""
Set the true parallelism that should be used during execution. This
should be specified by the user or detected by the optimizer.
"""
self._detected_parallelism = parallelism

def get_detected_parallelism(self) -> int:
"""
Get the true parallelism that should be used during execution.
"""
return self._detected_parallelism
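The getter/setter pair added above acts as a hand-off between optimization and planning: the optimizer decides the parallelism once and stores it on the logical Read op, and the planner later reads it back instead of re-detecting it. A schematic sketch, with ToyRead as a hypothetical stand-in for the real Read operator:

class ToyRead:
    """Hypothetical stand-in for the Read logical operator."""

    def __init__(self, requested_parallelism: int = -1):
        # -1 means "autodetect", mirroring op._parallelism in the diff.
        self._parallelism = requested_parallelism
        self._detected_parallelism = None

    def set_detected_parallelism(self, parallelism: int) -> None:
        self._detected_parallelism = parallelism

    def get_detected_parallelism(self) -> int:
        return self._detected_parallelism


# During optimization (SetReadParallelismRule below): decide the real value.
read_op = ToyRead(requested_parallelism=-1)
read_op.set_detected_parallelism(64)

# During planning (plan_read_op below): consume the stored value instead of
# re-running parallelism detection.
assert read_op.get_detected_parallelism() == 64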
10 changes: 6 additions & 4 deletions python/ray/data/_internal/logical/optimizers.py
@@ -10,11 +10,12 @@
add_user_provided_logical_rules,
add_user_provided_physical_rules,
)
from ray.data._internal.logical.rules.inherit_target_max_block_size import (
InheritTargetMaxBlockSizeRule,
)
from ray.data._internal.logical.rules.operator_fusion import OperatorFusionRule
from ray.data._internal.logical.rules.randomize_blocks import ReorderRandomizeBlocksRule
from ray.data._internal.logical.rules.split_read_output_blocks import (
SplitReadOutputBlocksRule,
)
from ray.data._internal.logical.rules.set_read_parallelism import SetReadParallelismRule
from ray.data._internal.logical.rules.zero_copy_map_fusion import (
EliminateBuildOutputBlocks,
)
@@ -25,7 +26,8 @@
]

DEFAULT_PHYSICAL_RULES = [
SplitReadOutputBlocksRule,
InheritTargetMaxBlockSizeRule,
SetReadParallelismRule,
OperatorFusionRule,
EliminateBuildOutputBlocks,
]
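Rule order in this list matters: block sizes must be inherited before SetReadParallelismRule reads them, which is why InheritTargetMaxBlockSizeRule comes first. A simplified sketch of how a rule-based optimizer applies such a list in order (an illustration only, not the actual Ray optimizer; ToyPlan and ToyRule are hypothetical):

from typing import List, Type


class ToyPlan:
    """Hypothetical placeholder for a physical plan."""


class ToyRule:
    """Hypothetical base class; each rule returns a (possibly rewritten) plan."""

    def apply(self, plan: ToyPlan) -> ToyPlan:
        return plan


def optimize(plan: ToyPlan, rules: List[Type[ToyRule]]) -> ToyPlan:
    # Rules run in list order, so a later rule (SetReadParallelismRule) can
    # rely on the effects of an earlier one (InheritTargetMaxBlockSizeRule).
    for rule_cls in rules:
        plan = rule_cls().apply(plan)
    return plan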
30 changes: 30 additions & 0 deletions python/ray/data/_internal/logical/rules/inherit_target_max_block_size.py
@@ -0,0 +1,30 @@
from typing import Optional

from ray.data._internal.execution.interfaces import PhysicalOperator
from ray.data._internal.logical.interfaces import PhysicalPlan, Rule


class InheritTargetMaxBlockSizeRule(Rule):
    """For each op that has overridden the default target max block size,
    propagate it to upstream ops until we reach an op that has also
    overridden the target max block size."""

    def apply(self, plan: PhysicalPlan) -> PhysicalPlan:
        self._propagate_target_max_block_size_to_upstream_ops(plan.dag)
        return plan

    def _propagate_target_max_block_size_to_upstream_ops(
        self, dag: PhysicalOperator, target_max_block_size: Optional[int] = None
    ):
        if dag.target_max_block_size is not None:
            # Set the target block size to inherit for upstream ops.
            target_max_block_size = dag.target_max_block_size
        elif target_max_block_size is not None:
            # Inherit from downstream op.
            dag.set_target_max_block_size(target_max_block_size)

        for upstream_op in dag.input_dependencies:
            self._propagate_target_max_block_size_to_upstream_ops(
                upstream_op, target_max_block_size
            )
41 changes: 4 additions & 37 deletions python/ray/data/_internal/logical/rules/operator_fusion.py
@@ -13,7 +13,6 @@
from ray.data._internal.execution.operators.base_physical_operator import (
AllToAllOperator,
)
from ray.data._internal.execution.operators.input_data_buffer import InputDataBuffer
from ray.data._internal.execution.operators.map_operator import MapOperator
from ray.data._internal.execution.operators.task_pool_map_operator import (
TaskPoolMapOperator,
@@ -79,8 +78,6 @@ def _fuse_map_operators_in_dag(self, dag: PhysicalOperator) -> MapOperator:
dag = self._get_fused_map_operator(dag, upstream_ops[0])
upstream_ops = dag.input_dependencies

self._propagate_target_max_block_size_to_input(dag)

# Done fusing back-to-back map operators together here,
# move up the DAG to find the next map operators to fuse.
dag._input_dependencies = [
@@ -105,23 +102,11 @@ def _fuse_all_to_all_operators_in_dag(
len(upstream_ops) == 1
and isinstance(dag, AllToAllOperator)
and isinstance(upstream_ops[0], MapOperator)
and self._can_fuse(dag, upstream_ops[0])
):
if self._can_fuse(dag, upstream_ops[0]):
# Fuse operator with its upstream op.
dag = self._get_fused_all_to_all_operator(dag, upstream_ops[0])
upstream_ops = dag.input_dependencies
else:
# Propagate target max block size to the upstream map op. This
# is necessary even when fusion is not allowed, so that the map
# op will produce the right block size for the shuffle op to
# consume.
map_op = upstream_ops[0]
map_op._target_max_block_size = self._get_merged_target_max_block_size(
upstream_ops[0].target_max_block_size, dag.target_max_block_size
)
break

self._propagate_target_max_block_size_to_input(dag)
# Fuse operator with its upstream op.
dag = self._get_fused_all_to_all_operator(dag, upstream_ops[0])
upstream_ops = dag.input_dependencies

# Done fusing MapOperator -> AllToAllOperator together here,
# move up the DAG to find the next pair of operators to fuse.
@@ -262,24 +247,6 @@ def _get_merged_target_max_block_size(
# blocks.
return down_target_max_block_size

def _propagate_target_max_block_size_to_input(self, dag):
# Operator fusion will merge target block sizes for adjacent operators,
# but if dag is the first op after a stage with read tasks, then we
# also need to propagate the block size to the input data buffer.
upstream_ops = dag.input_dependencies
if (
len(upstream_ops) == 1
and isinstance(upstream_ops[0], InputDataBuffer)
and self._can_merge_target_max_block_size(
upstream_ops[0].target_max_block_size, dag.target_max_block_size
)
):
upstream_ops[
0
]._target_max_block_size = self._get_merged_target_max_block_size(
upstream_ops[0].target_max_block_size, dag.target_max_block_size
)

def _get_fused_map_operator(
self, down_op: MapOperator, up_op: MapOperator
) -> MapOperator:
python/ray/data/_internal/logical/rules/set_read_parallelism.py (renamed from split_read_output_blocks.py)
@@ -71,19 +71,32 @@ def compute_additional_split_factor(
return parallelism, reason, estimated_num_blocks, None


class SplitReadOutputBlocksRule(Rule):
class SetReadParallelismRule(Rule):
"""
This rule sets the read op's task parallelism based on the target block
size, the requested parallelism, the number of read files, and the
available resources in the cluster.
If the parallelism is lower than requested, this rule also sets a split
factor to split the output blocks of the read task, so that the following
stage will have the desired parallelism.
"""

def apply(self, plan: PhysicalPlan) -> PhysicalPlan:
ops = [plan.dag]

while len(ops) == 1 and not isinstance(ops[0], InputDataBuffer):
logical_op = plan.op_map[ops[0]]
while len(ops) > 0:
op = ops.pop(0)
if isinstance(op, InputDataBuffer):
continue
logical_op = plan.op_map[op]
if isinstance(logical_op, Read):
self._split_read_op_if_needed(ops[0], logical_op)
ops = ops[0].input_dependencies
self._apply(op, logical_op)
ops += op.input_dependencies

return plan

def _split_read_op_if_needed(self, op: PhysicalOperator, logical_op: Read):
def _apply(self, op: PhysicalOperator, logical_op: Read):
(
detected_parallelism,
reason,
@@ -96,12 +109,15 @@ def _split_read_op_if_needed(self, op: PhysicalOperator, logical_op: Read):
op.actual_target_max_block_size,
op._additional_split_factor,
)

if logical_op._parallelism == -1:
assert reason != ""
logger.get_logger().info(
f"Using autodetected parallelism={detected_parallelism} "
f"for stage {logical_op.name} to satisfy {reason}."
)
logical_op.set_detected_parallelism(detected_parallelism)

if k is not None:
logger.get_logger().info(
f"To satisfy the requested parallelism of {detected_parallelism}, "
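The renamed rule's split-factor behavior can be illustrated with a small, self-contained sketch. additional_split_factor below is a hypothetical helper, not the real compute_additional_split_factor, which also takes the target block size, file sizes, and cluster resources into account:

import math
from typing import Optional


def additional_split_factor(
    requested_parallelism: int, estimated_num_read_tasks: int
) -> Optional[int]:
    # If the datasource cannot produce enough read tasks to reach the
    # requested parallelism, split each read task's output into k pieces so
    # that downstream stages still see roughly the requested block count.
    if estimated_num_read_tasks >= requested_parallelism:
        return None  # Enough read tasks already; no extra splitting needed.
    return math.ceil(requested_parallelism / estimated_num_read_tasks)


assert additional_split_factor(200, 50) == 4  # each read output split 4 ways
assert additional_split_factor(10, 32) is None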
2 changes: 1 addition & 1 deletion python/ray/data/_internal/plan.py
@@ -32,7 +32,7 @@
from ray.data._internal.lazy_block_list import LazyBlockList
from ray.data._internal.logical.operators.read_operator import Read
from ray.data._internal.logical.rules.operator_fusion import _are_remote_args_compatible
from ray.data._internal.logical.rules.split_read_output_blocks import (
from ray.data._internal.logical.rules.set_read_parallelism import (
compute_additional_split_factor,
)
from ray.data._internal.planner.plan_read_op import (
14 changes: 5 additions & 9 deletions python/ray/data/_internal/planner/plan_read_op.py
@@ -14,7 +14,7 @@
MapTransformFn,
)
from ray.data._internal.logical.operators.read_operator import Read
from ray.data._internal.util import _autodetect_parallelism, _warn_on_high_parallelism
from ray.data._internal.util import _warn_on_high_parallelism
from ray.data.block import Block
from ray.data.context import DataContext
from ray.data.datasource.datasource import ReadTask
@@ -48,14 +48,10 @@ def plan_read_op(op: Read) -> PhysicalOperator:
"""

def get_input_data(target_max_block_size) -> List[RefBundle]:
parallelism, _, min_safe_parallelism, _ = _autodetect_parallelism(
op._parallelism,
target_max_block_size,
DataContext.get_current(),
op._datasource_or_legacy_reader,
op._mem_size,
)

parallelism = op.get_detected_parallelism()
assert (
parallelism is not None
), "Read parallelism must be set by the optimizer before execution"
read_tasks = op._datasource_or_legacy_reader.get_read_tasks(parallelism)
_warn_on_high_parallelism(parallelism, len(read_tasks))

23 changes: 15 additions & 8 deletions python/ray/data/tests/conftest.py
@@ -510,23 +510,24 @@ def get_object_store_stats(self):
def get_actor_count(self):
return self.actor_count

def _assert_count_equals(self, actual_count, expected_count):
def _assert_count_equals(self, actual_count, expected_count, ignore_extra_tasks):
diff = {}
# Check that all tasks in expected tasks match those in actual task
# count.
for name, count in expected_count.items():
if not equals_or_true(actual_count[name], count):
diff[name] = (actual_count[name], count)
# Check that the actual task count does not have any additional tasks.
for name, count in actual_count.items():
if name not in expected_count and count != 0:
diff[name] = (count, 0)
if not ignore_extra_tasks:
for name, count in actual_count.items():
if name not in expected_count and count != 0:
diff[name] = (count, 0)

assert len(diff) == 0, "\nTask diff:\n" + "\n".join(
f" - {key}: expected {val[1]}, got {val[0]}" for key, val in diff.items()
)

def assert_task_metrics(self, expected_metrics):
def assert_task_metrics(self, expected_metrics, ignore_extra_tasks):
"""
Assert equality to the given { <task name>: <task count> }.
A lambda that takes in the count and returns a bool to assert can also
@@ -545,7 +546,9 @@ def assert_task_metrics(self, expected_metrics):
expected_task_count[name] = count

actual_task_count = self.get_task_count()
self._assert_count_equals(actual_task_count, expected_task_count)
self._assert_count_equals(
actual_task_count, expected_task_count, ignore_extra_tasks
)

def assert_object_store_metrics(self, expected_metrics):
"""
@@ -568,6 +571,7 @@ def assert_object_store_metrics(self):

actual_object_store_stats = self.get_object_store_stats()
for key, val in expected_object_store_stats.items():
print(f"{key}: Expect {val}, got {actual_object_store_stats[key]}")
assert equals_or_true(
actual_object_store_stats[key], val
), f"{key}: expected {val} got {actual_object_store_stats[key]}"
@@ -727,12 +731,15 @@ def get_initial_core_execution_metrics_snapshot():
task_count={"warmup": lambda count: True}, object_store_stats={}
),
last_snapshot=None,
ignore_extra_tasks=True,
)
return last_snapshot


def assert_core_execution_metrics_equals(
expected_metrics: CoreExecutionMetrics, last_snapshot=None
expected_metrics: CoreExecutionMetrics,
last_snapshot=None,
ignore_extra_tasks=False,
):
# Wait for one task per CPU to finish to prevent a race condition where not
# all of the task metrics have been collected yet.
@@ -742,7 +749,7 @@ def assert_core_execution_metrics_equals(
wait_for_condition(lambda: task_metrics_flushed(refs))

metrics = PhysicalCoreExecutionMetrics(last_snapshot)
metrics.assert_task_metrics(expected_metrics)
metrics.assert_task_metrics(expected_metrics, ignore_extra_tasks)
metrics.assert_object_store_metrics(expected_metrics)
metrics.assert_actor_metrics(expected_metrics)

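Finally, a usage sketch of the new ignore_extra_tasks flag, assuming the conftest helpers above are importable or otherwise in scope for a Ray Data test; the task name and workload are made up for illustration:

import ray

# get_initial_core_execution_metrics_snapshot, assert_core_execution_metrics_equals,
# and CoreExecutionMetrics come from python/ray/data/tests/conftest.py (see diff above).
last_snapshot = get_initial_core_execution_metrics_snapshot()

ds = ray.data.range(100).materialize()  # hypothetical workload under test

# Only assert on the read tasks; with ignore_extra_tasks=True, any other
# tasks that happened to run (e.g. warmup or stats tasks) no longer trip the
# "no additional tasks" check.
last_snapshot = assert_core_execution_metrics_equals(
    CoreExecutionMetrics(
        task_count={"ReadRange": lambda count: count > 0},  # made-up task name
        object_store_stats={},
    ),
    last_snapshot=last_snapshot,
    ignore_extra_tasks=True,
)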