[Dataset] Improve str/repr of Dataset to include execution plan #31604

Merged
2 commits merged on Jan 12, 2023
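
This PR changes `Dataset.__repr__` to delegate to a new `ExecutionPlan.get_plan_as_string()` method, so a dataset with lazily queued stages prints its pending execution plan as a small tree above the usual `Dataset(...)` summary. A minimal before/after sketch of the intended output, based on the doctest updates in this diff (exact `num_blocks` values depend on the environment):

```python
import ray

# Build a dataset with one lazily queued stage.
ds = ray.data.range(1000).map_batches(lambda batch: [v * 2 for v in batch])

# Before this change, repr(ds) was a single flat line:
#   Dataset(num_blocks=..., num_rows=..., schema=...)
# After this change, the pending execution plan is shown as well:
print(ds)
# MapBatches
# +- Dataset(num_blocks=..., num_rows=1000, schema=<class 'int'>)
```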
python/ray/data/_internal/plan.py (90 changes: 89 additions & 1 deletion)
@@ -136,6 +136,81 @@ def __repr__(self) -> str:
f"snapshot_blocks={self._snapshot_blocks})"
)

def get_plan_as_string(self) -> str:
"""Create a cosmetic string representation of this execution plan.

Returns:
The string representation of this execution plan.
"""

# NOTE: this is used for Dataset.__repr__ to give a user-facing string
# representation. Ideally ExecutionPlan.__repr__ should be replaced with this
# method as well.

# Do not force execution for schema, as this method is expected to be very
# cheap.
plan_str = ""
num_stages = 0
dataset_blocks = None
if self._stages_after_snapshot:
# Get string representation of each stage in reverse order.
for stage in self._stages_after_snapshot[::-1]:
# Get name of each stage in camel case.
stage_name = stage.name.title().replace("_", "")
if num_stages == 0:
plan_str += f"{stage_name}\n"
else:
trailing_space = " " * ((num_stages - 1) * 3)
plan_str += f"{trailing_space}+- {stage_name}\n"
num_stages += 1

# Get schema of initial blocks.
if self._snapshot_blocks is not None:
schema = self._get_unified_blocks_schema(
self._snapshot_blocks, fetch_if_missing=False
)
dataset_blocks = self._snapshot_blocks
else:
assert self._in_blocks is not None
schema = self._get_unified_blocks_schema(
self._in_blocks, fetch_if_missing=False
)
dataset_blocks = self._in_blocks
else:
# Get schema of output blocks.
schema = self.schema(fetch_if_missing=False)
dataset_blocks = self._snapshot_blocks

if schema is None:
schema_str = "Unknown schema"
elif isinstance(schema, type):
schema_str = str(schema)
else:
schema_str = []
for n, t in zip(schema.names, schema.types):
if hasattr(t, "__name__"):
t = t.__name__
schema_str.append(f"{n}: {t}")
schema_str = ", ".join(schema_str)
schema_str = "{" + schema_str + "}"
count = self._get_num_rows_from_blocks_metadata(dataset_blocks)
if count is None:
count = "?"
if dataset_blocks is None:
num_blocks = "?"
else:
num_blocks = dataset_blocks.initial_num_blocks()
dataset_str = "Dataset(num_blocks={}, num_rows={}, schema={})".format(
num_blocks, count, schema_str
)

if num_stages == 0:
plan_str = dataset_str
else:
trailing_space = " " * ((num_stages - 1) * 3)
plan_str += f"{trailing_space}+- {dataset_str}"
return plan_str

def with_stage(self, stage: "Stage") -> "ExecutionPlan":
"""Return a copy of this plan with the given stage appended.

@@ -260,6 +335,17 @@ def schema(
blocks = self._snapshot_blocks
if not blocks:
return None
return self._get_unified_blocks_schema(blocks, fetch_if_missing)

def _get_unified_blocks_schema(
self, blocks: BlockList, fetch_if_missing: bool = False
) -> Union[type, "pyarrow.lib.Schema"]:
"""Get the unified schema of the blocks.

Args:
blocks: the blocks to get schema
fetch_if_missing: Whether to execute the blocks to fetch the schema.
"""

# Only trigger the execution of first block in case it's a lazy block list.
# Don't trigger full execution for a schema read.
@@ -313,7 +399,9 @@ def meta_count(self) -> Optional[int]:
if self._stages_after_snapshot:
return None
# Snapshot is now guaranteed to be the output of the final stage or None.
blocks = self._snapshot_blocks
return self._get_num_rows_from_blocks_metadata(self._snapshot_blocks)

def _get_num_rows_from_blocks_metadata(self, blocks: BlockList) -> Optional[int]:
metadata = blocks.get_metadata() if blocks else None
if metadata and all(m.num_rows is not None for m in metadata):
return sum(m.num_rows for m in metadata)
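
The indentation in `get_plan_as_string` above grows by three spaces per nested stage, each level prefixed with `+-`, and the base `Dataset(...)` summary is printed as the innermost node. A standalone sketch of just that formatting rule (the `format_plan` helper below is hypothetical, not part of this PR), using plain stage-name strings in place of real `Stage` objects:

```python
from typing import List


def format_plan(stage_names: List[str], dataset_str: str) -> str:
    """Mimic the layout produced by ExecutionPlan.get_plan_as_string.

    stage_names is given outermost-first (the PR iterates
    _stages_after_snapshot in reverse to get this order), and the base
    dataset summary is appended as the innermost node.
    """
    if not stage_names:
        return dataset_str
    lines = [stage_names[0]]
    for depth, name in enumerate(stage_names[1:], start=1):
        # Each nested level is indented three more spaces than its parent.
        lines.append(" " * ((depth - 1) * 3) + "+- " + name)
    lines.append(" " * ((len(stage_names) - 1) * 3) + "+- " + dataset_str)
    return "\n".join(lines)


print(format_plan(
    ["RandomShuffle", "Filter", "MapBatches"],
    "Dataset(num_blocks=10, num_rows=10, schema=<class 'int'>)",
))
# RandomShuffle
# +- Filter
#    +- MapBatches
#       +- Dataset(num_blocks=10, num_rows=10, schema=<class 'int'>)
```

The expected output matches the `test_dataset_repr` assertions added below.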
python/ray/data/dataset.py (55 changes: 23 additions & 32 deletions)
@@ -177,19 +177,23 @@ class Dataset(Generic[T]):
>>> ds = ray.data.range(1000)
>>> # Transform in parallel with map_batches().
>>> ds.map_batches(lambda batch: [v * 2 for v in batch])
Dataset(num_blocks=..., num_rows=..., schema=...)
MapBatches
+- Dataset(num_blocks=17, num_rows=1000, schema=<class 'int'>)
>>> # Compute max.
>>> ds.max()
999
>>> # Group the data.
>>> ds.groupby(lambda x: x % 3).count()
Dataset(num_blocks=..., num_rows=..., schema=...)
Aggregate
+- Dataset(num_blocks=..., num_rows=1000, schema=<class 'int'>)
>>> # Shuffle this dataset randomly.
>>> ds.random_shuffle()
Dataset(num_blocks=..., num_rows=..., schema=...)
RandomShuffle
+- Dataset(num_blocks=..., num_rows=1000, schema=<class 'int'>)
>>> # Sort it back in order.
>>> ds.sort()
Dataset(num_blocks=..., num_rows=..., schema=...)
Sort
+- Dataset(num_blocks=..., num_rows=1000, schema=<class 'int'>)

Since Datasets are just lists of Ray object refs, they can be passed
between Ray tasks and actors without incurring a copy. Datasets support
@@ -241,7 +245,8 @@ def map(
>>> # Transform python objects.
>>> ds = ray.data.range(1000)
>>> ds.map(lambda x: x * 2)
Dataset(num_blocks=..., num_rows=..., schema=...)
Map
+- Dataset(num_blocks=..., num_rows=1000, schema=<class 'int'>)
>>> # Transform Arrow records.
>>> ds = ray.data.from_items(
... [{"value": i} for i in range(1000)])
@@ -795,7 +800,8 @@ def flat_map(
>>> import ray
>>> ds = ray.data.range(1000)
>>> ds.flat_map(lambda x: [x, x ** 2, x ** 3])
Dataset(num_blocks=..., num_rows=..., schema=...)
FlatMap
+- Dataset(num_blocks=..., num_rows=1000, schema=<class 'int'>)

Time complexity: O(dataset size / parallelism)

@@ -863,7 +869,8 @@ def filter(
>>> import ray
>>> ds = ray.data.range(100)
>>> ds.filter(lambda x: x % 2 == 0)
Dataset(num_blocks=..., num_rows=..., schema=...)
Filter
+- Dataset(num_blocks=..., num_rows=100, schema=<class 'int'>)

Time complexity: O(dataset size / parallelism)

@@ -957,10 +964,12 @@ def random_shuffle(
>>> ds = ray.data.range(100)
>>> # Shuffle this dataset randomly.
>>> ds.random_shuffle()
Dataset(num_blocks=..., num_rows=..., schema=...)
RandomShuffle
+- Dataset(num_blocks=..., num_rows=100, schema=<class 'int'>)
>>> # Shuffle this dataset with a fixed random seed.
>>> ds.random_shuffle(seed=12345)
Dataset(num_blocks=..., num_rows=..., schema=...)
RandomShuffle
+- Dataset(num_blocks=..., num_rows=100, schema=<class 'int'>)

Time complexity: O(dataset size / parallelism)

@@ -1524,7 +1533,8 @@ def groupby(self, key: Optional[KeyFn]) -> "GroupedDataset[T]":
>>> import ray
>>> # Group by a key function and aggregate.
>>> ray.data.range(100).groupby(lambda x: x % 3).count()
Dataset(num_blocks=..., num_rows=..., schema=...)
Aggregate
+- Dataset(num_blocks=..., num_rows=100, schema=<class 'int'>)
>>> # Group by an Arrow table column and aggregate.
>>> ray.data.from_items([
... {"A": x % 3, "B": x} for x in range(100)]).groupby(
@@ -1924,7 +1934,8 @@ def sort(
>>> # Sort using the entire record as the key.
>>> ds = ray.data.range(100)
>>> ds.sort()
Dataset(num_blocks=..., num_rows=..., schema=...)
Sort
+- Dataset(num_blocks=..., num_rows=100, schema=<class 'int'>)
>>> # Sort by a single column in descending order.
>>> ds = ray.data.from_items(
... [{"value": i} for i in range(1000)])
@@ -4217,27 +4228,7 @@ def _tab_repr_(self):
return Tab(children, titles=["Metadata", "Schema"])

def __repr__(self) -> str:
# Do not force execution for schema, as this method is expected to be very
# cheap.
schema = self.schema(fetch_if_missing=False)
if schema is None:
schema_str = "Unknown schema"
elif isinstance(schema, type):
schema_str = str(schema)
else:
schema_str = []
for n, t in zip(schema.names, schema.types):
if hasattr(t, "__name__"):
t = t.__name__
schema_str.append(f"{n}: {t}")
schema_str = ", ".join(schema_str)
schema_str = "{" + schema_str + "}"
count = self._meta_count()
if count is None:
count = "?"
return "Dataset(num_blocks={}, num_rows={}, schema={})".format(
self._plan.initial_num_blocks(), count, schema_str
)
return self._plan.get_plan_as_string()

def __str__(self) -> str:
return repr(self)
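
Because the new `__repr__` builds the string with `fetch_if_missing=False`, it never forces execution, so queued stages stay visible until the plan actually runs; once it does, the output collapses back to a single `Dataset(...)` line. A short session illustrating that behavior, mirroring the new test below (block and row counts assume the same `parallelism=10` setup):

```python
import ray

ds = ray.data.range(10, parallelism=10).map_batches(lambda x: x)
print(ds)
# MapBatches
# +- Dataset(num_blocks=10, num_rows=10, schema=<class 'int'>)

# Executing the plan materializes a snapshot with no pending stages,
# so the repr collapses back to the plain Dataset(...) form.
ds = ds.fully_executed()
print(ds)
# Dataset(num_blocks=10, num_rows=10, schema=<class 'int'>)
```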
python/ray/data/tests/test_dataset.py (43 changes: 43 additions & 0 deletions)
@@ -1498,6 +1498,49 @@ def test_lazy_loading_exponential_rampup(ray_start_regular_shared):
assert ds._plan.execute()._num_computed() == 20


def test_dataset_repr(ray_start_regular_shared):
ds = ray.data.range(10, parallelism=10)
assert repr(ds) == "Dataset(num_blocks=10, num_rows=10, schema=<class 'int'>)"
ds = ds.map_batches(lambda x: x)
assert repr(ds) == (
"MapBatches\n" "+- Dataset(num_blocks=10, num_rows=10, schema=<class 'int'>)"
)
ds = ds.filter(lambda x: x > 0)
assert repr(ds) == (
"Filter\n"
"+- MapBatches\n"
" +- Dataset(num_blocks=10, num_rows=10, schema=<class 'int'>)"
)
ds = ds.random_shuffle()
assert repr(ds) == (
"RandomShuffle\n"
"+- Filter\n"
" +- MapBatches\n"
" +- Dataset(num_blocks=10, num_rows=10, schema=<class 'int'>)"
)
ds.fully_executed()
assert repr(ds) == "Dataset(num_blocks=10, num_rows=9, schema=<class 'int'>)"
ds = ds.map_batches(lambda x: x)
assert repr(ds) == (
"MapBatches\n" "+- Dataset(num_blocks=10, num_rows=9, schema=<class 'int'>)"
)
ds1, ds2 = ds.split(2)
assert (
repr(ds1)
== f"Dataset(num_blocks=5, num_rows={ds1.count()}, schema=<class 'int'>)"
)
assert (
repr(ds2)
== f"Dataset(num_blocks=5, num_rows={ds2.count()}, schema=<class 'int'>)"
)
ds3 = ds1.union(ds2)
assert repr(ds3) == "Dataset(num_blocks=10, num_rows=9, schema=<class 'int'>)"
ds = ds.zip(ds3)
assert repr(ds) == (
"Zip\n" "+- Dataset(num_blocks=10, num_rows=9, schema=<class 'int'>)"
)


@pytest.mark.parametrize("lazy", [False, True])
def test_limit(ray_start_regular_shared, lazy):
ds = ray.data.range(100, parallelism=20)