diff --git a/python/ray/data/_internal/output_buffer.py b/python/ray/data/_internal/output_buffer.py
index b3a0c54823cb2..dcd07fe4ebcfa 100644
--- a/python/ray/data/_internal/output_buffer.py
+++ b/python/ray/data/_internal/output_buffer.py
@@ -1,7 +1,7 @@
 from typing import Any
 
 from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder
-from ray.data.block import Block, DataBatch
+from ray.data.block import Block, BlockAccessor, DataBatch
 
 
 class BlockOutputBuffer:
@@ -68,7 +68,22 @@ def has_next(self) -> bool:
 
     def next(self) -> Block:
         """Returns the next complete output block."""
         assert self.has_next()
-        block = self._buffer.build()
+
+        block_to_yield = self._buffer.build()
+        block_remainder = None
+        block = BlockAccessor.for_block(block_to_yield)
+        if block.size_bytes() > self._target_max_block_size:
+            num_bytes_per_row = block.size_bytes() // block.num_rows()
+            target_num_rows = self._target_max_block_size // num_bytes_per_row
+            target_num_rows = max(1, target_num_rows)
+
+            num_rows = min(target_num_rows, block.num_rows())
+            block_to_yield = block.slice(0, num_rows)
+            block_remainder = block.slice(num_rows, block.num_rows())
+        self._buffer = DelegatingBlockBuilder()
+        if block_remainder is not None:
+            self._buffer.add_block(block_remainder)
+
         self._returned_at_least_one_block = True
-        return block
+        return block_to_yield
diff --git a/python/ray/data/tests/test_dynamic_block_split.py b/python/ray/data/tests/test_dynamic_block_split.py
index 8faa975243bad..c910507a54bf5 100644
--- a/python/ray/data/tests/test_dynamic_block_split.py
+++ b/python/ray/data/tests/test_dynamic_block_split.py
@@ -116,7 +116,7 @@ def test_dataset(
     # Test 10 tasks, each task returning 10 blocks, each block has 1 row and each
     # row has 1024 bytes.
     num_blocks_per_task = 10
-    block_size = 1024
+    block_size = target_max_block_size
     num_tasks = 10
 
     ds = ray.data.read_datasource(
@@ -131,11 +131,13 @@ def test_dataset(
     assert ds.num_blocks() == num_tasks
     assert ds.size_bytes() >= 0.7 * block_size * num_blocks_per_task * num_tasks
 
+    # Too-large blocks will get split to respect target max block size.
     map_ds = ds.map_batches(lambda x: x, compute=compute)
     map_ds = map_ds.materialize()
-    assert map_ds.num_blocks() == num_tasks
+    assert map_ds.num_blocks() == num_tasks * num_blocks_per_task
+    # Blocks smaller than requested batch size will get coalesced.
     map_ds = ds.map_batches(
-        lambda x: x, batch_size=num_blocks_per_task * num_tasks, compute=compute
+        lambda x: {}, batch_size=num_blocks_per_task * num_tasks, compute=compute
     )
     map_ds = map_ds.materialize()
     assert map_ds.num_blocks() == 1
diff --git a/python/ray/data/tests/test_splitblocks.py b/python/ray/data/tests/test_splitblocks.py
index 629d0f360dfc9..33c0948a7f946 100644
--- a/python/ray/data/tests/test_splitblocks.py
+++ b/python/ray/data/tests/test_splitblocks.py
@@ -23,7 +23,7 @@ def f(n, k):
     f(50, 5)
 
 
-def test_small_file_split(ray_start_10_cpus_shared):
+def test_small_file_split(ray_start_10_cpus_shared, restore_data_context):
     ds = ray.data.read_csv("example://iris.csv", parallelism=1)
     assert ds.num_blocks() == 1
     assert ds.materialize().num_blocks() == 1
@@ -48,6 +48,14 @@ def test_small_file_split(ray_start_10_cpus_shared):
     assert "Stage 1 ReadCSV->SplitBlocks(100)" in stats, stats
     assert "Stage 2 MapBatches" in stats, stats
 
+    ctx = ray.data.context.DataContext.get_current()
+    # Smaller than a single row.
+    ctx.target_max_block_size = 1
+    ds = ds.map_batches(lambda x: x).materialize()
+    # 150 rows.
+    assert ds.num_blocks() == 150
+    print(ds.stats())
+
 
 def test_large_file_additional_split(ray_start_10_cpus_shared, tmp_path):
     ctx = ray.data.context.DataContext.get_current()
@@ -74,6 +82,23 @@ def test_large_file_additional_split(ray_start_10_cpus_shared, tmp_path):
 
     assert 500 < ds.materialize().num_blocks() < 2000
 
 
+def test_map_batches_split(ray_start_10_cpus_shared, restore_data_context):
+    ds = ray.data.range(1000, parallelism=1).map_batches(lambda x: x, batch_size=1000)
+    assert ds.materialize().num_blocks() == 1
+
+    ctx = ray.data.context.DataContext.get_current()
+    # 100 integer rows per block.
+    ctx.target_max_block_size = 800
+
+    ds = ray.data.range(1000, parallelism=1).map_batches(lambda x: x, batch_size=1000)
+    assert ds.materialize().num_blocks() == 10
+
+    # A single row is already larger than the target block
+    # size.
+    ctx.target_max_block_size = 4
+    assert ds.materialize().num_blocks() == 1000
+
+
 if __name__ == "__main__":
     import sys
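
For context on the output_buffer.py change: next() now yields at most roughly target_max_block_size bytes of rows per call and carries the remainder forward in a fresh DelegatingBlockBuilder, so an oversized buffer is split across successive calls. Below is a minimal sketch of that slicing arithmetic, assuming a constant per-row size estimate; estimate_slices is a hypothetical stand-in for illustration, not part of this patch or the Ray API.

    # Illustrative sketch only: models the (start_row, end_row) slices that
    # repeated next() calls would produce, under the simplifying assumption
    # that the per-row byte estimate stays constant across calls.
    def estimate_slices(block_size_bytes, num_rows, target_max_block_size):
        start = 0
        while start < num_rows:
            # Same integer arithmetic as the patch, floored at one row so
            # that rows larger than the target still make progress.
            num_bytes_per_row = max(1, block_size_bytes // num_rows)
            target_num_rows = max(1, target_max_block_size // num_bytes_per_row)
            end = min(start + target_num_rows, num_rows)
            yield start, end
            start = end

    # A 10 KiB block of 10 rows against a 1 KiB target splits into ten
    # single-row slices, matching the updated test_dataset expectation.
    assert list(estimate_slices(10 * 1024, 10, 1024)) == [(i, i + 1) for i in range(10)]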