[data] Slice output blocks to respect target block size #40248

Merged 2 commits on Oct 12, 2023
21 changes: 18 additions & 3 deletions python/ray/data/_internal/output_buffer.py
@@ -1,7 +1,7 @@
 from typing import Any
 
 from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder
-from ray.data.block import Block, DataBatch
+from ray.data.block import Block, BlockAccessor, DataBatch
 
 
 class BlockOutputBuffer:
@@ -68,7 +68,22 @@ def has_next(self) -> bool:
     def next(self) -> Block:
         """Returns the next complete output block."""
         assert self.has_next()
-        block = self._buffer.build()
 
+        block_to_yield = self._buffer.build()
+        block_remainder = None
+        block = BlockAccessor.for_block(block_to_yield)
+        if block.size_bytes() > self._target_max_block_size:
+            num_bytes_per_row = block.size_bytes() // block.num_rows()
+            target_num_rows = self._target_max_block_size // num_bytes_per_row
+            target_num_rows = max(1, target_num_rows)
+
+            num_rows = min(target_num_rows, block.num_rows())
+            block_to_yield = block.slice(0, num_rows)
+            block_remainder = block.slice(num_rows, block.num_rows())
+
+        self._buffer = DelegatingBlockBuilder()
+        if block_remainder is not None:
+            self._buffer.add_block(block_remainder)
+
         self._returned_at_least_one_block = True
-        return block
+        return block_to_yield
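
The change above makes BlockOutputBuffer.next() compare the built block's byte size against the configured target and, when it is too large, yield only a prefix of rows, pushing the remainder back into a fresh DelegatingBlockBuilder for a later call. A minimal sketch of the row-count arithmetic (plain Python, not Ray code; the fixed-row-size assumption mirrors the integer division above):

def rows_to_yield(num_rows: int, size_bytes: int, target_max_block_size: int) -> int:
    # How many leading rows to slice off so the yielded block stays near the target size.
    if size_bytes <= target_max_block_size:
        return num_rows
    num_bytes_per_row = size_bytes // num_rows
    target_num_rows = max(1, target_max_block_size // num_bytes_per_row)
    return min(target_num_rows, num_rows)

# Example: a 10-row, 10240-byte block with a 4096-byte target yields 4 rows;
# the remaining 6 rows stay buffered and are emitted on a later next() call.
assert rows_to_yield(10, 10240, 4096) == 4
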
8 changes: 5 additions & 3 deletions python/ray/data/tests/test_dynamic_block_split.py
@@ -116,7 +116,7 @@ def test_dataset(
     # Test 10 tasks, each task returning 10 blocks, each block has 1 row and each
     # row has 1024 bytes.
     num_blocks_per_task = 10
-    block_size = 1024
+    block_size = target_max_block_size
     num_tasks = 10
 
     ds = ray.data.read_datasource(
@@ -131,11 +131,13 @@
     assert ds.num_blocks() == num_tasks
     assert ds.size_bytes() >= 0.7 * block_size * num_blocks_per_task * num_tasks
 
+    # Too-large blocks will get split to respect target max block size.
     map_ds = ds.map_batches(lambda x: x, compute=compute)
     map_ds = map_ds.materialize()
-    assert map_ds.num_blocks() == num_tasks
+    assert map_ds.num_blocks() == num_tasks * num_blocks_per_task
+    # Blocks smaller than requested batch size will get coalesced.
     map_ds = ds.map_batches(
-        lambda x: x, batch_size=num_blocks_per_task * num_tasks, compute=compute
+        lambda x: {}, batch_size=num_blocks_per_task * num_tasks, compute=compute
     )
     map_ds = map_ds.materialize()
     assert map_ds.num_blocks() == 1
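
The two comments above describe complementary behaviors: map_batches output that exceeds target_max_block_size is sliced into multiple blocks, while a batch_size spanning many small blocks pulls them together into one. A hedged, self-contained sketch of the same pattern outside the test harness (block counts are illustrative and depend on row size):

import ray

ctx = ray.data.context.DataContext.get_current()
ctx.target_max_block_size = 800  # roughly 100 eight-byte integer rows per block (assumed row size)

# Splitting: one upstream block is sliced into ~10 output blocks.
ds = ray.data.range(1000, parallelism=1).map_batches(lambda x: x, batch_size=1000)
print(ds.materialize().num_blocks())

# Coalescing: with a generous target and a batch_size covering all rows,
# many small upstream blocks merge into a single output block.
ctx.target_max_block_size = 128 * 1024 * 1024  # illustrative large target
ds = ray.data.range(1000, parallelism=10).map_batches(lambda x: x, batch_size=1000)
print(ds.materialize().num_blocks())  # expected: 1
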
27 changes: 26 additions & 1 deletion python/ray/data/tests/test_splitblocks.py
@@ -23,7 +23,7 @@ def f(n, k):
     f(50, 5)
 
 
-def test_small_file_split(ray_start_10_cpus_shared):
+def test_small_file_split(ray_start_10_cpus_shared, restore_data_context):
     ds = ray.data.read_csv("example://iris.csv", parallelism=1)
     assert ds.num_blocks() == 1
     assert ds.materialize().num_blocks() == 1
@@ -48,6 +48,14 @@ def test_small_file_split(ray_start_10_cpus_shared):
assert "Stage 1 ReadCSV->SplitBlocks(100)" in stats, stats
assert "Stage 2 MapBatches" in stats, stats

ctx = ray.data.context.DataContext.get_current()
# Smaller than a single row.
ctx.target_max_block_size = 1
ds = ds.map_batches(lambda x: x).materialize()
# 150 rows.
assert ds.num_blocks() == 150
print(ds.stats())


def test_large_file_additional_split(ray_start_10_cpus_shared, tmp_path):
ctx = ray.data.context.DataContext.get_current()
@@ -74,6 +82,23 @@ def test_large_file_additional_split(ray_start_10_cpus_shared, tmp_path):
     assert 500 < ds.materialize().num_blocks() < 2000
 
 
+def test_map_batches_split(ray_start_10_cpus_shared, restore_data_context):
+    ds = ray.data.range(1000, parallelism=1).map_batches(lambda x: x, batch_size=1000)
+    assert ds.materialize().num_blocks() == 1
+
+    ctx = ray.data.context.DataContext.get_current()
+    # 100 integer rows per block.
+    ctx.target_max_block_size = 800
+
+    ds = ray.data.range(1000, parallelism=1).map_batches(lambda x: x, batch_size=1000)
+    assert ds.materialize().num_blocks() == 10
+
+    # A single row is already larger than the target block
+    # size.
+    ctx.target_max_block_size = 4
+    assert ds.materialize().num_blocks() == 1000
+
+
 if __name__ == "__main__":
     import sys
 
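
The last assertion in the new test exercises the max(1, ...) clamp in next(): when a single row already exceeds target_max_block_size, the integer division produces zero target rows, which is clamped to one, so every row becomes its own block. A small illustration of that arithmetic (plain Python, assuming 8-byte integer rows as in the comment above):

target_max_block_size = 4
num_bytes_per_row = 8  # assumed size of one integer row
target_num_rows = max(1, target_max_block_size // num_bytes_per_row)
assert target_num_rows == 1  # 1000 rows therefore materialize as 1000 blocks
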