[Datasets] Support different number of blocks/rows per block in zip(). (#32795) (#32998)

This PR adds support for `ds1.zip(ds2)` when the two datasets have a different number of blocks and/or different numbers of rows per block, by aligning the blocks in `ds2` with those in `ds1` via a lightweight repartition/block split.

## Design

We heavily utilize the block-splitting machinery that's used for `ds.split()` and `ds.split_at_indices()` to avoid an overly expensive repartition. Namely, for `ds1.zip(ds2)`, we do the following (see the sketch after this list):
1. Calculate the block sizes for `ds1` in order to get split indices.
2. Apply `_split_at_indices()` to `ds2` in order to get a list of `ds2` block chunks for every block in `ds1`, such that `self_block.num_rows() == sum(other_block.num_rows() for other_block in other_split_blocks)` for every `self_block` in `ds1`.
3. Zip together each block in `ds1` with the one or more blocks from `ds2` that constitute the block-aligned split for that `ds1` block.
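
As a rough, non-authoritative sketch of these three steps, the snippet below aligns two toy "datasets" represented as plain Python lists of blocks. The `align_and_zip` helper and the list-of-lists representation are illustrative stand-ins, not Ray Datasets internals (which operate on block object refs and split the existing blocks in place via `_split_at_indices()` rather than flattening rows):

```python
# Rough sketch of the three steps above on plain Python lists standing in for
# blocks; `align_and_zip` is a made-up helper, not Ray Datasets internals.
import itertools
from typing import List, Tuple

def align_and_zip(ds1_blocks: List[list], ds2_blocks: List[list]) -> List[List[Tuple]]:
    # Step 1: cumulative row counts of ds1 give the global split indices.
    indices = list(itertools.accumulate(len(b) for b in ds1_blocks))[:-1]
    # Step 2: split ds2 at those indices. The real implementation splits the
    # existing blocks in place instead of flattening all rows.
    flat = [row for block in ds2_blocks for row in block]
    bounds = [0] + indices + [len(flat)]
    ds2_aligned = [flat[lo:hi] for lo, hi in zip(bounds[:-1], bounds[1:])]
    # Step 3: zip each ds1 block with its row-aligned ds2 chunk.
    return [list(zip(b1, b2)) for b1, b2 in zip(ds1_blocks, ds2_aligned)]

ds1 = [[0, 1, 2], [3, 4]]        # 2 blocks with 3 and 2 rows
ds2 = [[10, 11], [12, 13, 14]]   # 2 blocks with 2 and 3 rows
print(align_and_zip(ds1, ds2))
# [[(0, 10), (1, 11), (2, 12)], [(3, 13), (4, 14)]]
```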
clarkzinzow authored Mar 3, 2023
1 parent 3cf8ce8 commit 6010649
Showing 7 changed files with 247 additions and 59 deletions.
6 changes: 5 additions & 1 deletion python/ray/data/_internal/equalize.py
@@ -152,7 +152,11 @@ def _split_leftovers(
prev = split_indices[i]
split_result: Tuple[
List[List[ObjectRef[Block]]], List[List[BlockMetadata]]
] = _split_at_indices(leftovers, split_indices)
] = _split_at_indices(
leftovers.get_blocks_with_metadata(),
split_indices,
leftovers._owned_by_consumer,
)
return [list(zip(block_refs, meta)) for block_refs, meta in zip(*split_result)][
:num_splits
]
10 changes: 9 additions & 1 deletion python/ray/data/_internal/plan.py
@@ -745,7 +745,15 @@ def _run_with_new_execution_backend(self) -> bool:
not self.is_read_stage_equivalent()
or trailing_randomize_block_order_stage
)
and self._stages_after_snapshot
and (
self._stages_after_snapshot
# If snapshot is cleared, we'll need to recompute from the source.
or (
self._snapshot_blocks is not None
and self._snapshot_blocks.is_cleared()
and self._stages_before_snapshot
)
)
)


26 changes: 18 additions & 8 deletions python/ray/data/_internal/split.py
@@ -154,13 +154,13 @@ def _drop_empty_block_split(block_split_indices: List[int], num_rows: int) -> Li


def _split_all_blocks(
block_list: BlockList,
blocks_with_metadata: List[Tuple[ObjectRef[Block], BlockMetadata]],
per_block_split_indices: List[List[int]],
owned_by_consumer: bool,
) -> Iterable[Tuple[ObjectRef[Block], BlockMetadata]]:
"""Split all the input blocks based on the split indices"""
split_single_block = cached_remote_fn(_split_single_block)

blocks_with_metadata = block_list.get_blocks_with_metadata()
all_blocks_split_results: List[BlockPartition] = [None] * len(blocks_with_metadata)

per_block_split_metadata_futures = []
@@ -203,7 +203,7 @@ def _split_all_blocks(
# We make a copy for the blocks that have been split, so the input blocks
# can be cleared if they are owned by consumer (consumer-owned blocks will
# only be consumed by the owner).
if block_list._owned_by_consumer:
if owned_by_consumer:
for b in blocks_splitted:
trace_deallocation(b, "split._split_all_blocks")
else:
@@ -246,26 +246,32 @@ def _generate_global_split_results(


def _split_at_indices(
block_list: BlockList,
blocks_with_metadata: List[Tuple[ObjectRef[Block], BlockMetadata]],
indices: List[int],
owned_by_consumer: bool = True,
block_rows: List[int] = None,
) -> Tuple[List[List[ObjectRef[Block]]], List[List[BlockMetadata]]]:
"""Split blocks at the provided indices.
Args:
blocks_with_metadata: Block futures to split, including the associated metadata.
indices: The (global) indices at which to split the blocks.
owned_by_consumer: Whether the provided blocks are owned by the consumer.
block_rows: The number of rows for each block, in case it has already been
computed.
Returns:
The block split futures and their metadata. If an index split is empty, the
corresponding block split will be empty.
"""

blocks_with_metadata = block_list.get_blocks_with_metadata()
# We implement the split in 3 phases.
# phase 1: calculate the per block split indices.
blocks_with_metadata = list(blocks_with_metadata)
if len(blocks_with_metadata) == 0:
return ([[]] * (len(indices) + 1), [[]] * (len(indices) + 1))
block_rows: List[int] = _calculate_blocks_rows(blocks_with_metadata)
if block_rows is None:
block_rows = _calculate_blocks_rows(blocks_with_metadata)
valid_indices = _generate_valid_indices(block_rows, indices)
per_block_split_indices: List[List[int]] = _generate_per_block_split_indices(
block_rows, valid_indices
@@ -274,7 +280,9 @@ def _split_at_indices(
# phase 2: split each block based on the indices from previous step.
all_blocks_split_results: Iterable[
Tuple[ObjectRef[Block], BlockMetadata]
] = _split_all_blocks(block_list, per_block_split_indices)
] = _split_all_blocks(
blocks_with_metadata, per_block_split_indices, owned_by_consumer
)

# phase 3: generate the final split.

@@ -306,5 +314,7 @@ def _split_at_index(
Returns:
The block split futures and their metadata for left and right of the index.
"""
blocks_splits, metadata_splits = _split_at_indices(block_list, [index])
blocks_splits, metadata_splits = _split_at_indices(
block_list.get_blocks_with_metadata(), [index], block_list._owned_by_consumer
)
return blocks_splits[0], metadata_splits[0], blocks_splits[1], metadata_splits[1]
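
For orientation, here is a hedged sketch of calling the updated `_split_at_indices()` signature from this hunk. It mirrors the call sites elsewhere in this PR (e.g. `Dataset.split_at_indices()` below) rather than a documented public API, and the dataset and indices are made up:

```python
# Hypothetical call against the updated signature; mirrors this PR's call
# sites, which pass blocks-with-metadata and the ownership flag explicitly.
import ray
from ray.data._internal.split import _split_at_indices

ds = ray.data.range(100, parallelism=4)
block_list = ds._plan.execute()
blocks, metadata = _split_at_indices(
    block_list.get_blocks_with_metadata(),  # block refs + metadata, not a BlockList
    [25, 75],                                # global row indices to split at
    block_list._owned_by_consumer,           # ownership flag now passed explicitly
)
# blocks[i] / metadata[i] hold the block refs and metadata for split i
# (here: rows [0, 25), [25, 75), and [75, 100)).
```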
158 changes: 129 additions & 29 deletions python/ray/data/_internal/stage_impl.py
@@ -1,4 +1,5 @@
from typing import Any, Dict, Optional, TYPE_CHECKING
import itertools
from typing import Any, Dict, Tuple, List, Optional, TYPE_CHECKING

import ray
from ray.data._internal.fast_repartition import fast_repartition
@@ -7,14 +8,17 @@
PushBasedShufflePartitionOp,
SimpleShufflePartitionOp,
)
from ray.data._internal.split import _split_at_indices
from ray.data._internal.block_list import BlockList
from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder
from ray.data._internal.execution.interfaces import TaskContext
from ray.data._internal.remote_fn import cached_remote_fn
from ray.data._internal.sort import sort_impl
from ray.data.context import DatasetContext
from ray.data.block import (
_validate_key_fn,
Block,
BlockPartition,
KeyFn,
BlockMetadata,
BlockAccessor,
@@ -147,50 +151,146 @@ class ZipStage(AllToAllStage):
"""Implementation of `Dataset.zip()`."""

def __init__(self, other: "Dataset"):
def do_zip_all(block_list, clear_input_blocks: bool, *_):
blocks1 = block_list.get_blocks()
blocks2 = other.get_internal_block_refs()

if clear_input_blocks:
block_list.clear()

if len(blocks1) != len(blocks2):
# TODO(ekl) consider supporting if num_rows are equal.
def do_zip_all(block_list: BlockList, clear_input_blocks: bool, *_):
# Repartition other to align with the base dataset, and then zip together
# the blocks in parallel.
# TODO(Clark): Port this to a streaming zip, e.g. push block pairs through
# an actor that buffers and zips.
base_block_list = block_list
base_blocks_with_metadata = block_list.get_blocks_with_metadata()
base_block_rows, base_block_bytes = _calculate_blocks_rows_and_bytes(
base_blocks_with_metadata
)
# Execute other to a block list.
other_block_list = other._plan.execute()
other_blocks_with_metadata = other_block_list.get_blocks_with_metadata()
other_block_rows, other_block_bytes = _calculate_blocks_rows_and_bytes(
other_blocks_with_metadata
)
inverted = False
if sum(other_block_bytes) > sum(base_block_bytes):
# Make sure that other is the smaller dataset, so we minimize splitting
# work when aligning other with base.
# TODO(Clark): Improve this heuristic for minimizing splitting work,
# e.g. by generating the splitting plans for each route (via
# _generate_per_block_split_indices) and choosing the plan that splits
# the least cumulative bytes.
base_block_list, other_block_list = other_block_list, base_block_list
base_blocks_with_metadata, other_blocks_with_metadata = (
other_blocks_with_metadata,
base_blocks_with_metadata,
)
base_block_rows, other_block_rows = other_block_rows, base_block_rows
inverted = True
# Get the split indices that will align other with base.
indices = list(itertools.accumulate(base_block_rows))
indices.pop(-1)

# Check that each dataset has the same number of rows.
# TODO(Clark): Support different number of rows via user-directed
# dropping/padding.
total_base_rows = sum(base_block_rows)
total_other_rows = sum(other_block_rows)
if total_base_rows != total_other_rows:
raise ValueError(
"Cannot zip dataset of different num blocks: {} vs {}".format(
len(blocks1), len(blocks2)
)
"Cannot zip datasets of different number of rows: "
f"{total_base_rows}, {total_other_rows}"
)

def do_zip(block1: Block, block2: Block) -> (Block, BlockMetadata):
stats = BlockExecStats.builder()
b1 = BlockAccessor.for_block(block1)
result = b1.zip(block2)
br = BlockAccessor.for_block(result)
return result, br.get_metadata(input_files=[], exec_stats=stats.build())
# Split other at the alignment indices, such that for every block in
# block_list, we have a list of blocks from other that has the same
# cumulative number of rows as that block.
# NOTE: _split_at_indices has a no-op fastpath if the blocks are already
# aligned.
aligned_other_blocks_with_metadata = _split_at_indices(
other_blocks_with_metadata,
indices,
other_block_list._owned_by_consumer,
other_block_rows,
)
del other_blocks_with_metadata

base_blocks = [b for b, _ in base_blocks_with_metadata]
other_blocks = aligned_other_blocks_with_metadata[0]
del base_blocks_with_metadata, aligned_other_blocks_with_metadata
if clear_input_blocks:
base_block_list.clear()
other_block_list.clear()

do_zip_fn = cached_remote_fn(do_zip, num_returns=2)
do_zip = cached_remote_fn(_do_zip, num_returns=2)

blocks = []
metadata = []
for b1, b2 in zip(blocks1, blocks2):
res, meta = do_zip_fn.remote(b1, b2)
blocks.append(res)
metadata.append(meta)
out_blocks = []
out_metadata = []
for base_block, other_blocks in zip(base_blocks, other_blocks):
# For each block in base, zip it together with 1 or more blocks from
# other. We're guaranteed that base_block has the same number of rows
# as other_blocks has cumulatively.
res, meta = do_zip.remote(base_block, *other_blocks, inverted=inverted)
out_blocks.append(res)
out_metadata.append(meta)

# Early release memory.
del blocks1, blocks2
del base_blocks, other_blocks

# TODO(ekl) it might be nice to have a progress bar here.
metadata = ray.get(metadata)
out_metadata = ray.get(out_metadata)
blocks = BlockList(
blocks, metadata, owned_by_consumer=block_list._owned_by_consumer
out_blocks,
out_metadata,
owned_by_consumer=base_block_list._owned_by_consumer,
)
return blocks, {}

super().__init__("zip", None, do_zip_all)


def _calculate_blocks_rows_and_bytes(
blocks_with_metadata: BlockPartition,
) -> Tuple[List[int], List[int]]:
"""Calculate the number of rows and size in bytes for a list of blocks with
metadata.
"""
get_num_rows_and_bytes = cached_remote_fn(_get_num_rows_and_bytes)
block_rows = []
block_bytes = []
for block, metadata in blocks_with_metadata:
if metadata.num_rows is None or metadata.size_bytes is None:
# Need to fetch number of rows or size in bytes, so just fetch both.
num_rows, size_bytes = ray.get(get_num_rows_and_bytes.remote(block))
# Cache on the block metadata.
metadata.num_rows = num_rows
metadata.size_bytes = size_bytes
block_rows.append(metadata.num_rows)
block_bytes.append(metadata.size_bytes)
return block_rows, block_bytes


def _get_num_rows_and_bytes(block: Block) -> Tuple[int, int]:
block = BlockAccessor.for_block(block)
return block.num_rows(), block.size_bytes()


def _do_zip(
block: Block, *other_blocks: Block, inverted: bool = False
) -> Tuple[Block, BlockMetadata]:
# Zips together block with other_blocks.
stats = BlockExecStats.builder()
# Concatenate other blocks.
# TODO(Clark): Extend BlockAccessor.zip() to work with N other blocks,
# so we don't need to do this concatenation.
builder = DelegatingBlockBuilder()
for other_block in other_blocks:
builder.add_block(other_block)
other_block = builder.build()
if inverted:
# Swap blocks if ordering was inverted during block alignment splitting.
block, other_block = other_block, block
# Zip block and other blocks.
result = BlockAccessor.for_block(block).zip(other_block)
br = BlockAccessor.for_block(result)
return result, br.get_metadata(input_files=[], exec_stats=stats.build())
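
As a rough pandas analogue of what `_do_zip()` above does for tabular blocks (concatenate the one or more `other` blocks, then zip column-wise); the column names and data are made up, and the real code goes through `DelegatingBlockBuilder` and `BlockAccessor.zip()` on Ray block types:

```python
# Illustrative only: pandas stand-ins for the block concatenation and zip
# steps in _do_zip(); not the actual Ray Datasets code path.
import pandas as pd

base_block = pd.DataFrame({"a": [0, 1, 2, 3, 4]})
other_blocks = [pd.DataFrame({"b": [10, 11]}), pd.DataFrame({"b": [12, 13, 14]})]

other = pd.concat(other_blocks, ignore_index=True)  # concatenate the other blocks
zipped = pd.concat([base_block, other], axis=1)     # column-wise "zip"
print(zipped)
#    a   b
# 0  0  10
# 1  1  11
# 2  2  12
# 3  3  13
# 4  4  14
```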


class SortStage(AllToAllStage):
"""Implementation of `Dataset.sort()`."""

39 changes: 26 additions & 13 deletions python/ray/data/dataset.py
@@ -1363,7 +1363,11 @@ def split_at_indices(self, indices: List[int]) -> List["Dataset[T]"]:
raise ValueError("indices must be positive")
start_time = time.perf_counter()
block_list = self._plan.execute()
blocks, metadata = _split_at_indices(block_list, indices)
blocks, metadata = _split_at_indices(
block_list.get_blocks_with_metadata(),
indices,
block_list._owned_by_consumer,
)
split_duration = time.perf_counter() - start_time
parent_stats = self._plan.stats()
splits = []
@@ -2029,29 +2033,38 @@ def sort(
def zip(self, other: "Dataset[U]") -> "Dataset[(T, U)]":
"""Zip this dataset with the elements of another.
The datasets must have identical num rows, block types, and block sizes,
e.g. one was produced from a :meth:`~.map` of another. For Arrow
blocks, the schema will be concatenated, and any duplicate column
names disambiguated with _1, _2, etc. suffixes.
The datasets must have the same number of rows. For tabular datasets, the
datasets will be concatenated horizontally; namely, their column sets will be
merged, and any duplicate column names disambiguated with _1, _2, etc. suffixes.
.. note::
The smaller of the two datasets will be repartitioned to align the number of
rows per block with the larger dataset.
.. note::
Zipped datasets are not lineage-serializable, i.e. they can not be used as a
tunable hyperparameter in Ray Tune.
Examples:
>>> import ray
>>> ds1 = ray.data.range(5)
>>> ds2 = ray.data.range(5, parallelism=2).map(lambda x: x + 1)
>>> ds1.zip(ds2).take()
[(0, 1), (1, 2), (2, 3), (3, 4), (4, 5)]
Time complexity: O(dataset size / parallelism)
Args:
other: The dataset to zip with on the right hand side.
Examples:
>>> import ray
>>> ds = ray.data.range(5)
>>> ds.zip(ds).take()
[(0, 0), (1, 1), (2, 2), (3, 3), (4, 4)]
Returns:
A Dataset with (k, v) pairs (or concatenated Arrow schema) where k
comes from the first dataset and v comes from the second.
If the inputs are simple datasets, this returns a ``Dataset`` containing
(k, v) pairs, where k comes from the first dataset and v comes from the
second.
If the inputs are tabular datasets, this returns a ``Dataset`` containing
the columns of the second dataset concatenated horizontally with the columns
of the first dataset, with duplicate column names disambiguated with _1, _2,
etc. suffixes.
"""

plan = self._plan.with_stage(ZipStage(other))