Commit

[Data] Update ExecutionPlan.execute_to_iterator() to return `RefBundles` instead of `(Block, BlockMetadata)` (#46575)

Follow-up to #46369 and #46455.
Update `ExecutionPlan.execute_to_iterator()` to return `RefBundles`
instead of `(Block, BlockMetadata)`, to unify the logic between
`RefBundle`s and `Block`s. Also refactor the `iter_batches()` code path
accordingly to handle `RefBundle`s instead of raw `Block` and
`BlockMetadata`.
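
For orientation, a minimal sketch of the interface change (illustrative only; `plan` stands in for an `ExecutionPlan` instance):

```python
# Before: the iterator yielded (ObjectRef[Block], BlockMetadata) pairs.
block_iter, stats, executor = plan.execute_to_iterator()
for block_ref, metadata in block_iter:
    ...  # one block at a time

# After: the iterator yields RefBundles; each bundle groups one or more
# (ObjectRef[Block], BlockMetadata) pairs under `bundle.blocks`.
bundle_iter, stats, executor = plan.execute_to_iterator()
for bundle in bundle_iter:
    for block_ref, metadata in bundle.blocks:
        ...  # same data, grouped per bundle
```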

Signed-off-by: sjl <sjl@anyscale.com>
Signed-off-by: Scott Lee <sjl@anyscale.com>
scottjlee committed Jul 16, 2024
1 parent 17dfcfd commit 1b0af29
Showing 9 changed files with 124 additions and 146 deletions.
2 changes: 1 addition & 1 deletion python/ray/data/BUILD
@@ -21,7 +21,7 @@ py_library(
py_test_module_list(
files = glob(["tests/block_batching/test_*.py"]),
size = "medium",
tags = ["team:ml", "exclusive"],
tags = ["team:data", "exclusive"],
deps = ["//:ray_lib", ":conftest"],
)

47 changes: 27 additions & 20 deletions python/ray/data/_internal/block_batching/iter_batches.py
@@ -1,6 +1,6 @@
import collections
from contextlib import nullcontext
- from typing import Any, Callable, Dict, Iterator, Optional, Tuple
+ from typing import Any, Callable, Dict, Iterator, Optional

import ray
from ray.data._internal.block_batching.interfaces import Batch, BlockPrefetcher
@@ -14,16 +14,17 @@
format_batches,
resolve_block_refs,
)
+ from ray.data._internal.execution.interfaces.ref_bundle import RefBundle
from ray.data._internal.memory_tracing import trace_deallocation
from ray.data._internal.stats import DatasetStats
from ray.data._internal.util import make_async_gen
- from ray.data.block import Block, BlockMetadata, DataBatch
+ from ray.data.block import Block, DataBatch
from ray.data.context import DataContext
from ray.types import ObjectRef


def iter_batches(
-     block_refs: Iterator[Tuple[ObjectRef[Block], BlockMetadata]],
+     ref_bundles: Iterator[RefBundle],
*,
stats: Optional[DatasetStats] = None,
clear_block_after_read: bool = False,
@@ -71,8 +72,7 @@ def iter_batches(
6. Fetch outputs from the threadpool, maintaining order of the batches.
Args:
-     block_refs: An iterator over block object references and their corresponding
-         metadata.
+     ref_bundles: An iterator over RefBundles.
stats: DatasetStats object to record timing and other statistics.
clear_block_after_read: Whether to clear the block from object store
manually (i.e. without waiting for Python's automatic GC) after it
@@ -121,19 +121,19 @@ def iter_batches(
eager_free = clear_block_after_read and DataContext.get_current().eager_free

def _async_iter_batches(
-     block_refs: Iterator[Tuple[ObjectRef[Block], BlockMetadata]],
+     ref_bundles: Iterator[RefBundle],
) -> Iterator[DataBatch]:
# Step 1: Prefetch logical batches locally.
-     block_refs = prefetch_batches_locally(
-         block_ref_iter=block_refs,
+     block_iter = prefetch_batches_locally(
+         ref_bundles=ref_bundles,
prefetcher=prefetcher,
num_batches_to_prefetch=prefetch_batches,
batch_size=batch_size,
eager_free=eager_free,
)

# Step 2: Resolve the blocks.
-     block_iter = resolve_block_refs(block_ref_iter=block_refs, stats=stats)
+     block_iter = resolve_block_refs(block_ref_iter=block_iter, stats=stats)

# Step 3: Batch and shuffle the resolved blocks.
batch_iter = blocks_to_batches(
@@ -168,7 +168,9 @@ def _async_iter_batches(

# Run everything in a separate thread to not block the main thread when waiting
# for streaming results.
- async_batch_iter = make_async_gen(block_refs, fn=_async_iter_batches, num_workers=1)
+ async_batch_iter = make_async_gen(
+     ref_bundles, fn=_async_iter_batches, num_workers=1
+ )

while True:
with stats.iter_total_blocked_s.timer() if stats else nullcontext():
@@ -229,17 +231,18 @@ def threadpool_computations_format_collate(


def prefetch_batches_locally(
-     block_ref_iter: Iterator[Tuple[ObjectRef[Block], BlockMetadata]],
+     ref_bundles: Iterator[RefBundle],
prefetcher: BlockPrefetcher,
num_batches_to_prefetch: int,
batch_size: Optional[int],
eager_free: bool = False,
) -> Iterator[ObjectRef[Block]]:
"""Given an iterator of batched block references, returns an iterator over the same
block references while prefetching `num_batches_to_prefetch` batches in advance.
"""Given an iterator of batched RefBundles, returns an iterator over the
corresponding block references while prefetching `num_batches_to_prefetch`
batches in advance.
Args:
-         block_ref_iter: An iterator over batched block references.
+         ref_bundles: An iterator over batched RefBundles.
prefetcher: The prefetcher to use.
num_batches_to_prefetch: The number of batches to prefetch ahead of the
current batch during the scan.
@@ -251,8 +254,9 @@
current_window_size = 0

if num_batches_to_prefetch <= 0:
-     for block_ref, metadata in block_ref_iter:
-         yield block_ref
+     for ref_bundle in ref_bundles:
+         for block_ref in ref_bundle.block_refs:
+             yield block_ref
return

if batch_size is not None:
@@ -268,11 +272,11 @@
batch_size is None and len(sliding_window) < num_batches_to_prefetch
):
try:
-         next_block_ref_and_metadata = next(block_ref_iter)
+         next_ref_bundle = next(ref_bundles)
+         sliding_window.extend(next_ref_bundle.blocks)
+         current_window_size += next_ref_bundle.num_rows()
except StopIteration:
break
-     sliding_window.append(next_block_ref_and_metadata)
-     current_window_size += next_block_ref_and_metadata[1].num_rows

prefetcher.prefetch_blocks([block_ref for block_ref, _ in list(sliding_window)])

@@ -281,7 +285,10 @@
current_window_size -= metadata.num_rows
if batch_size is None or current_window_size < num_rows_to_prefetch:
try:
-         sliding_window.append(next(block_ref_iter))
+         next_ref_bundle = next(ref_bundles)
+         for block_ref_and_md in next_ref_bundle.blocks:
+             sliding_window.append(block_ref_and_md)
+             current_window_size += block_ref_and_md[1].num_rows
prefetcher.prefetch_blocks(
[block_ref for block_ref, _ in list(sliding_window)]
)
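The core of the rewrite above is that the prefetch window now fills one whole `RefBundle` at a time rather than one block. A condensed, self-contained sketch of that loop (simplified from `prefetch_batches_locally`: row-count-based windowing and the prefetcher are omitted, and `RefBundle` is stubbed to the single field used here):

```python
import collections
from typing import Any, Iterator, List, Tuple

class StubRefBundle:
    """Stand-in exposing only the field this sketch needs."""

    def __init__(self, blocks: List[Tuple[Any, Any]]):
        self.blocks = blocks  # list of (block_ref, metadata) pairs

def iter_block_refs(
    ref_bundles: Iterator[StubRefBundle], num_bundles_to_prefetch: int
) -> Iterator[Any]:
    sliding_window: collections.deque = collections.deque()
    # Fill the initial window bundle by bundle, not block by block.
    while len(sliding_window) < num_bundles_to_prefetch:
        try:
            sliding_window.extend(next(ref_bundles).blocks)
        except StopIteration:
            break
    # Emit the oldest block, topping the window up with the next bundle.
    while sliding_window:
        block_ref, _metadata = sliding_window.popleft()
        try:
            sliding_window.extend(next(ref_bundles).blocks)
        except StopIteration:
            pass
        yield block_ref

# Usage sketch:
bundles = iter([StubRefBundle([("ref-1", None), ("ref-2", None)]),
                StubRefBundle([("ref-3", None)])])
assert list(iter_block_refs(bundles, num_bundles_to_prefetch=2)) == [
    "ref-1", "ref-2", "ref-3"
]
```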
15 changes: 5 additions & 10 deletions python/ray/data/_internal/iterator/iterator_impl.py
@@ -1,10 +1,9 @@
from typing import TYPE_CHECKING, Iterator, Optional, Tuple, Union

+ from ray.data._internal.execution.interfaces.ref_bundle import RefBundle
from ray.data._internal.stats import DatasetStats
from ray.data._internal.util import create_dataset_tag
- from ray.data.block import Block, BlockMetadata
from ray.data.iterator import DataIterator
- from ray.types import ObjectRef

if TYPE_CHECKING:
import pyarrow
@@ -22,17 +21,13 @@ def __init__(
def __repr__(self) -> str:
return f"DataIterator({self._base_dataset})"

-     def _to_block_iterator(
+     def _to_ref_bundle_iterator(
self,
-     ) -> Tuple[
-         Iterator[Tuple[ObjectRef[Block], BlockMetadata]],
-         Optional[DatasetStats],
-         bool,
-     ]:
+     ) -> Tuple[Iterator[RefBundle], Optional[DatasetStats], bool]:
ds = self._base_dataset
-         block_iterator, stats, executor = ds._plan.execute_to_iterator()
+         ref_bundles_iterator, stats, executor = ds._plan.execute_to_iterator()
ds._current_executor = executor
-         return block_iterator, stats, False
+         return ref_bundles_iterator, stats, False

def stats(self) -> str:
return self._base_dataset.stats()
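The renamed `_to_ref_bundle_iterator` keeps the same three-tuple contract; only the first slot changes from `(block_ref, metadata)` pairs to `RefBundle`s. A hedged consumption sketch (`it` is assumed to be a `DataIterator` backed by a dataset; the meaning of the final bool is assumed to follow the old `_to_block_iterator` contract, i.e. whether the consumer owns the blocks):

```python
# Sketch only: unpack the three-tuple returned by the renamed method.
ref_bundles_iterator, stats, owns_blocks = it._to_ref_bundle_iterator()
for bundle in ref_bundles_iterator:
    # num_rows() aggregates row counts from the bundle's block metadata.
    print(bundle.num_rows())
```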
20 changes: 8 additions & 12 deletions python/ray/data/_internal/iterator/stream_split_iterator.py
@@ -68,31 +68,27 @@ def __init__(
self._world_size = world_size
self._iter_stats = DatasetStats(metadata={}, parent=None)

-     def _to_block_iterator(
+     def _to_ref_bundle_iterator(
self,
-     ) -> Tuple[
-         Iterator[Tuple[ObjectRef[Block], BlockMetadata]],
-         Optional[DatasetStats],
-         bool,
-     ]:
-         def gen_blocks() -> Iterator[Tuple[ObjectRef[Block], BlockMetadata]]:
+     ) -> Tuple[Iterator[RefBundle], Optional[DatasetStats], bool]:
+         def gen_blocks() -> Iterator[RefBundle]:
cur_epoch = ray.get(
self._coord_actor.start_epoch.remote(self._output_split_idx)
)
future: ObjectRef[
Optional[ObjectRef[Block]]
] = self._coord_actor.get.remote(cur_epoch, self._output_split_idx)
while True:
-                 block_ref: Optional[Tuple[ObjectRef[Block], BlockMetadata]] = ray.get(
-                     future
-                 )
-                 if not block_ref:
+                 block_ref_and_md: Optional[
+                     Tuple[ObjectRef[Block], BlockMetadata]
+                 ] = ray.get(future)
+                 if not block_ref_and_md:
break
else:
future = self._coord_actor.get.remote(
cur_epoch, self._output_split_idx
)
-                     yield block_ref
+                     yield RefBundle(blocks=(block_ref_and_md,), owns_blocks=False)

return gen_blocks(), self._iter_stats, False

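The essential adaptation here is wrapping each `(block_ref, metadata)` pair fetched from the split coordinator actor into a single-block `RefBundle`. As a standalone sketch (the pair is assumed already retrieved via `ray.get` from the coordinator):

```python
from ray.data._internal.execution.interfaces.ref_bundle import RefBundle

def wrap_pair(block_ref_and_md) -> RefBundle:
    # owns_blocks=False mirrors the diff: the coordinator still owns the
    # blocks, so downstream consumers must not eagerly free them.
    return RefBundle(blocks=(block_ref_and_md,), owns_blocks=False)
```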
52 changes: 23 additions & 29 deletions python/ray/data/_internal/plan.py
@@ -15,10 +15,9 @@
from ray.data._internal.logical.operators.read_operator import Read
from ray.data._internal.stats import DatasetStats, DatasetStatsSummary
from ray.data._internal.util import create_dataset_tag, unify_block_metadata_schema
- from ray.data.block import Block, BlockMetadata
+ from ray.data.block import BlockMetadata
from ray.data.context import DataContext
from ray.data.exceptions import omit_traceback_stdout
- from ray.types import ObjectRef
from ray.util.debug import log_once

if TYPE_CHECKING:
@@ -346,20 +345,24 @@ def schema(
elif self._logical_plan.dag.schema() is not None:
schema = self._logical_plan.dag.schema()
elif fetch_if_missing:
-             blocks_with_metadata, _, _ = self.execute_to_iterator()
-             for _, metadata in blocks_with_metadata:
-                 if metadata.schema is not None and (
-                     metadata.num_rows is None or metadata.num_rows > 0
-                 ):
-                     schema = metadata.schema
-                     break
+             iter_ref_bundles, _, _ = self.execute_to_iterator()
+             for ref_bundle in iter_ref_bundles:
+                 for metadata in ref_bundle.metadata:
+                     if metadata.schema is not None and (
+                         metadata.num_rows is None or metadata.num_rows > 0
+                     ):
+                         schema = metadata.schema
+                         break
elif self.is_read_only():
# For consistency with the previous implementation, we fetch the schema if
# the plan is read-only even if `fetch_if_missing` is False.
-             blocks_with_metadata, _, _ = self.execute_to_iterator()
+             iter_ref_bundles, _, _ = self.execute_to_iterator()
try:
-                 _, metadata = next(iter(blocks_with_metadata))
-                 schema = metadata.schema
+                 ref_bundle = next(iter(iter_ref_bundles))
+                 for metadata in ref_bundle.metadata:
+                     if metadata.schema is not None:
+                         schema = metadata.schema
+                         break
except StopIteration: # Empty dataset.
schema = None
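
The schema probe above reads naturally as a small helper; a sketch under the same assumptions (`ref_bundle.metadata` is the list of `BlockMetadata` for the bundle's blocks):

```python
from typing import Iterable, Optional

def first_schema(ref_bundles: Iterable) -> Optional[object]:
    """Return the first schema attached to a block that may contain rows."""
    for ref_bundle in ref_bundles:
        for metadata in ref_bundle.metadata:
            if metadata.schema is not None and (
                metadata.num_rows is None or metadata.num_rows > 0
            ):
                return metadata.schema
    return None
```

Using `return` exits both loops at once; note that the `break` in the diff only leaves the inner loop, so the outer loop keeps pulling bundles after a schema is found.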

@@ -392,17 +395,13 @@ def meta_count(self) -> Optional[int]:
@omit_traceback_stdout
def execute_to_iterator(
self,
-     ) -> Tuple[
-         Iterator[Tuple[ObjectRef[Block], BlockMetadata]],
-         DatasetStats,
-         Optional["Executor"],
-     ]:
+     ) -> Tuple[Iterator[RefBundle], DatasetStats, Optional["Executor"]]:
"""Execute this plan, returning an iterator.
This will use streaming execution to generate outputs.
Returns:
-             Tuple of iterator over output blocks and the executor.
+             Tuple of iterator over output RefBundles, DatasetStats, and the executor.
"""
self._has_started_execution = True

@@ -411,30 +410,25 @@ def execute_to_iterator(

if self.has_computed_output():
bundle = self.execute()
-             return iter(bundle.blocks), self._snapshot_stats, None
+             return iter([bundle]), self._snapshot_stats, None

from ray.data._internal.execution.legacy_compat import (
-             execute_to_legacy_block_iterator,
+             execute_to_legacy_bundle_iterator,
)
from ray.data._internal.execution.streaming_executor import StreamingExecutor

metrics_tag = create_dataset_tag(self._dataset_name, self._dataset_uuid)
executor = StreamingExecutor(copy.deepcopy(ctx.execution_options), metrics_tag)
-         # TODO(scottjlee): replace with `execute_to_legacy_bundle_iterator` and
-         # update execute_to_iterator usages to handle RefBundles instead of Blocks
-         block_iter = execute_to_legacy_block_iterator(
-             executor,
-             self,
-         )
+         bundle_iter = execute_to_legacy_bundle_iterator(executor, self)
# Since the generator doesn't run any code until we try to fetch the first
# value, force execution of one bundle before we call get_stats().
-         gen = iter(block_iter)
+         gen = iter(bundle_iter)
try:
-             block_iter = itertools.chain([next(gen)], gen)
+             bundle_iter = itertools.chain([next(gen)], gen)
except StopIteration:
pass
self._snapshot_stats = executor.get_stats()
-         return block_iter, self._snapshot_stats, executor
+         return bundle_iter, self._snapshot_stats, executor

@omit_traceback_stdout
def execute(
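The `itertools.chain` pattern above, forcing a generator to produce its first bundle so `get_stats()` reflects real execution and then stitching the stream back together, generalizes to a small utility; a sketch:

```python
import itertools
from typing import Iterator, TypeVar

T = TypeVar("T")

def force_first(it: Iterator[T]) -> Iterator[T]:
    """Eagerly pull one element, then yield the stream intact."""
    gen = iter(it)
    try:
        first = next(gen)
    except StopIteration:
        return iter(())  # an empty stream stays empty
    return itertools.chain([first], gen)
```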
10 changes: 1 addition & 9 deletions python/ray/data/dataset.py
@@ -81,7 +81,6 @@
VALID_BATCH_FORMATS,
Block,
BlockAccessor,
-     BlockMetadata,
DataBatch,
T,
U,
@@ -4682,14 +4681,7 @@ def iter_internal_ref_bundles(self) -> Iterator[RefBundle]:
An iterator over this Dataset's ``RefBundles``.
"""

-         def _build_ref_bundles(
-             iter_blocks: Iterator[Tuple[ObjectRef[Block], BlockMetadata]],
-         ) -> Iterator[RefBundle]:
-             for block in iter_blocks:
-                 yield RefBundle((block,), owns_blocks=True)
-
-         iter_block_refs_md, _, _ = self._plan.execute_to_iterator()
-         iter_ref_bundles = _build_ref_bundles(iter_block_refs_md)
+         iter_ref_bundles, _, _ = self._plan.execute_to_iterator()
self._synchronize_progress_bar()
return iter_ref_bundles

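With `execute_to_iterator()` already producing `RefBundle`s, the local `_build_ref_bundles` shim becomes unnecessary and the method is a straight pass-through. An end-to-end usage sketch (assumes a running local Ray; `iter_internal_ref_bundles` is the developer API shown in the diff):

```python
import ray

ds = ray.data.range(8)
for bundle in ds.iter_internal_ref_bundles():
    for block_ref, metadata in bundle.blocks:
        block = ray.get(block_ref)  # resolve the block behind the ObjectRef
        print(metadata.num_rows, type(block))
```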
