[RFC][Dataset] Actor based prefetching #23952

Merged (10 commits) on Apr 29, 2022
6 changes: 6 additions & 0 deletions python/ray/data/context.py
@@ -29,6 +29,9 @@
# Whether to furthermore fuse prior map tasks with shuffle stages.
DEFAULT_OPTIMIZE_FUSE_SHUFFLE_STAGES = True

# Whether to use the actor-based block prefetcher.
DEFAULT_ACTOR_PREFETCHER_ENABLED = True

# Whether to use push-based shuffle by default.
DEFAULT_USE_PUSH_BASED_SHUFFLE = bool(
os.environ.get("RAY_DATASET_PUSH_BASED_SHUFFLE", None)
@@ -52,6 +55,7 @@ def __init__(
optimize_fuse_stages: bool,
optimize_fuse_read_stages: bool,
optimize_fuse_shuffle_stages: bool,
actor_prefetcher_enabled: bool,
use_push_based_shuffle: bool,
):
"""Private constructor (use get_current() instead)."""
@@ -62,6 +66,7 @@ def __init__(
self.optimize_fuse_stages = optimize_fuse_stages
self.optimize_fuse_read_stages = optimize_fuse_read_stages
self.optimize_fuse_shuffle_stages = optimize_fuse_shuffle_stages
self.actor_prefetcher_enabled = actor_prefetcher_enabled
self.use_push_based_shuffle = use_push_based_shuffle

@staticmethod
@@ -84,6 +89,7 @@ def get_current() -> "DatasetContext":
optimize_fuse_stages=DEFAULT_OPTIMIZE_FUSE_STAGES,
optimize_fuse_read_stages=DEFAULT_OPTIMIZE_FUSE_READ_STAGES,
optimize_fuse_shuffle_stages=DEFAULT_OPTIMIZE_FUSE_SHUFFLE_STAGES,
actor_prefetcher_enabled=DEFAULT_ACTOR_PREFETCHER_ENABLED,
use_push_based_shuffle=DEFAULT_USE_PUSH_BASED_SHUFFLE,
)

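For context (not part of the diff): the new flag can be toggled at runtime through the singleton context, as the test added below does to exercise the ray.wait fallback. A minimal usage sketch, assuming only the fields shown in this diff:

import ray
from ray.data.context import DatasetContext

ray.init()
# The context is a per-driver singleton, so this affects all
# subsequent Dataset iteration in this process.
context = DatasetContext.get_current()
context.actor_prefetcher_enabled = False  # fall back to ray.wait prefetching

ds = ray.data.range(1000)
for batch in ds.iter_batches(prefetch_blocks=2):
    pass  # consume batches; blocks ahead of the cursor are prefetched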
62 changes: 61 additions & 1 deletion python/ray/data/impl/block_batching.py
@@ -9,13 +9,17 @@
import numpy as np

import ray
from ray.actor import ActorHandle
from ray.types import ObjectRef
from ray.data.block import Block, BlockAccessor
from ray.data.context import DatasetContext
from ray.data.impl.batcher import Batcher
from ray.data.impl.stats import DatasetStats, DatasetPipelineStats
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy

# An output type of iter_batches() determined by the batch_format parameter.
BatchType = Union["pandas.DataFrame", "pyarrow.Table", np.ndarray, list]
PREFETCHER_ACTOR_NAMESPACE = "ray.dataset"


def batch_blocks(
@@ -61,10 +65,19 @@ def batch_block(block: ObjectRef[Block]):
yield result

block_window = [] # Handle empty sliding window gracefully.
context = DatasetContext.get_current()
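# Choose a prefetching strategy: the actor-based prefetcher is used only
# when it is enabled and we are not connected through Ray Client;
# otherwise fall back to plain ray.wait-based prefetching.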
if (
prefetch_blocks > 0
and context.actor_prefetcher_enabled
and not ray.util.client.ray.is_connected()
):
prefetcher = ActorBlockPrefetcher()
else:
prefetcher = WaitBlockPrefetcher()
for block_window in _sliding_window(blocks, prefetch_blocks + 1):
block_window = list(block_window)
with stats.iter_wait_s.timer():
prefetcher.prefetch_blocks(block_window)
yield from batch_block(block_window[0])

# Consume remainder of final block window.
@@ -128,3 +141,50 @@ def _sliding_window(iterable: Iterable, n: int):
for elem in it:
window.append(elem)
yield tuple(window)
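# Illustration (not part of this diff): with n == 2, _sliding_window
# over [a, b, c] yields (a, b) then (b, c), so in batch_blocks each
# window pairs the block being consumed with the next `prefetch_blocks`
# blocks to warm up.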


class BlockPrefetcher:
"""Interface for prefetching blocks."""

def prefetch_blocks(self, blocks: List[ObjectRef[Block]]):
"""Prefetch the provided blocks to this node."""
raise NotImplementedError


class WaitBlockPrefetcher(BlockPrefetcher):
"""Block prefetcher using ray.wait."""

def prefetch_blocks(self, blocks: List[ObjectRef[Block]]):
ray.wait(blocks, num_returns=1, fetch_local=True)


# ray.wait doesn't work as expected, so we have an
# actor-based prefetcher as a workaround. See
# https://github.com/ray-project/ray/issues/23983 for details.
class ActorBlockPrefetcher(BlockPrefetcher):
"""Block prefetcher using a local actor."""

def __init__(self):
self.prefetch_actor = self._get_or_create_actor_prefetcher()

@staticmethod
def _get_or_create_actor_prefetcher() -> "ActorHandle":
node_id = ray.get_runtime_context().node_id
actor_name = f"dataset-block-prefetcher-{node_id}"
return _BlockPrefetcher.options(
scheduling_strategy=NodeAffinitySchedulingStrategy(node_id, soft=False),
name=actor_name,
namespace=PREFETCHER_ACTOR_NAMESPACE,
get_if_exists=True,
).remote()

def prefetch_blocks(self, blocks: List[ObjectRef[Block]]):
self.prefetch_actor.prefetch.remote(*blocks)


@ray.remote(num_cpus=0, placement_group=None)
class _BlockPrefetcher:
"""Helper actor that prefetches blocks asynchronously."""

def prefetch(self, *blocks) -> None:
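# The empty body is intentional: Ray resolves the ObjectRef arguments
# before running the task, which pulls the blocks into this node's
# object store.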
pass
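The same warm-up trick, distilled into a self-contained sketch (the actor and variable names here are illustrative, not part of the PR): because Ray resolves ObjectRef task arguments on the receiving node, a no-op method on a zero-CPU, node-pinned actor is enough to pull objects into the local object store.

import ray
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy

@ray.remote(num_cpus=0)
class _Warmup:
    def touch(self, *objects) -> None:
        # By the time this runs, the arguments have been fetched, so the
        # objects already live in this node's object store.
        pass

ray.init()
refs = [ray.put(b"x" * 1024) for _ in range(8)]
node_id = ray.get_runtime_context().node_id
warmer = _Warmup.options(
    scheduling_strategy=NodeAffinitySchedulingStrategy(node_id, soft=False)
).remote()
# Fire and forget: the objects stream to this node in the background.
warmer.touch.remote(*refs)

The PR additionally names the actor and creates it with get_if_exists=True, so every iterator on a node shares a single prefetcher instead of spawning one per call.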
10 changes: 10 additions & 0 deletions python/ray/data/tests/test_dataset.py
@@ -16,6 +16,7 @@
from ray.data.dataset import Dataset, _sliding_window
from ray.data.datasource.csv_datasource import CSVDatasource
from ray.data.block import BlockAccessor
from ray.data.context import DatasetContext
from ray.data.row import TableRow
from ray.data.impl.arrow_block import ArrowRow
from ray.data.impl.block_builder import BlockBuilder
@@ -1445,6 +1446,15 @@ def test_iter_batches_basic(ray_start_regular_shared):
assert isinstance(batch, pd.DataFrame)
assert batch.equals(df)

# Prefetch with ray.wait (actor-based prefetcher disabled).
context = DatasetContext.get_current()
context.actor_prefetcher_enabled = False
batches = list(ds.iter_batches(prefetch_blocks=1, batch_format="pandas"))
assert len(batches) == len(dfs)
for batch, df in zip(batches, dfs):
assert isinstance(batch, pd.DataFrame)
assert batch.equals(df)


def test_iter_batches_grid(ray_start_regular_shared):
# Tests slicing, batch combining, and partial batch dropping logic over