Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[data] Implement zero-copy fusion for Read op #38789

Merged
merged 10 commits into from
Aug 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 85 additions & 18 deletions python/ray/data/_internal/execution/operators/map_transformer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import itertools
from abc import abstractmethod
from enum import Enum
from typing import Any, Callable, Dict, Iterable, List, Optional, TypeVar, Union

Expand Down Expand Up @@ -31,7 +32,6 @@ class MapTransformFn:

def __init__(
self,
callable: MapTransformCallable[MapTransformFnData, MapTransformFnData],
input_type: MapTransformFnDataType,
output_type: MapTransformFnDataType,
):
Expand All @@ -45,10 +45,11 @@ def __init__(
self._input_type = input_type
self._output_type = output_type

@abstractmethod
def __call__(
self, input: Iterable[MapTransformFnData], ctx: TaskContext
) -> Iterable[MapTransformFnData]:
return self._callable(input, ctx)
...

@property
def input_type(self) -> MapTransformFnDataType:
Expand Down Expand Up @@ -80,6 +81,11 @@ def __init__(
init_fn: A function that will be called before transforming data.
Used for the actor-based map operator.
"""
self.set_transform_fns(transform_fns)
self._init_fn = init_fn if init_fn is not None else lambda: None

def set_transform_fns(self, transform_fns: List[MapTransformFn]) -> None:
"""Set the transform functions."""
assert len(transform_fns) > 0
assert (
transform_fns[0].input_type == MapTransformFnDataType.Block
Expand All @@ -93,9 +99,11 @@ def __init__(
"The output type of the previous transform function must match "
"the input type of the next transform function."
)

self._transform_fns = transform_fns
self._init_fn = init_fn if init_fn is not None else lambda: None

def get_transform_fns(self) -> List[MapTransformFn]:
"""Get the transform functions."""
return self._transform_fns

def init(self) -> None:
"""Initialize the transformer.
Expand Down Expand Up @@ -140,32 +148,78 @@ def create_map_transformer_from_block_fn(
"""
return MapTransformer(
[
MapTransformFn(
block_fn,
MapTransformFnDataType.Block,
MapTransformFnDataType.Block,
)
BlockMapTransformFn(block_fn),
],
init_fn,
)


# Below are util `MapTransformFn`s for converting input/output data.
# Below are subclasses of MapTransformFn.


class RowMapTransformFn(MapTransformFn):
"""A rows-to-rows MapTransformFn."""

def __init__(self, row_fn: MapTransformCallable[Row, Row]):
self._row_fn = row_fn
super().__init__(
MapTransformFnDataType.Row,
MapTransformFnDataType.Row,
)

def __call__(self, input: Iterable[Row], ctx: TaskContext) -> Iterable[Row]:
yield from self._row_fn(input, ctx)

def __repr__(self) -> str:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

where is this being used?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently not being used. I only used this when debugging, and decided to keep it as having a better repr won't hurt.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

now it's used in the unit test.

return f"RowMapTransformFn({self._row_fn})"


class BatchMapTransformFn(MapTransformFn):
    """A MapTransformFn whose input and output are both batches.

    It wraps a user-supplied callable and declares
    `MapTransformFnDataType.Batch` as both its input and output type.
    """

    def __init__(self, batch_fn: MapTransformCallable[DataBatch, DataBatch]):
        """Store the wrapped callable and register the Batch -> Batch types.

        Args:
            batch_fn: The callable invoked with the input batches and the
                task context each time this transform fn is called.
        """
        # Keep the callable on the instance before delegating to the base
        # class, mirroring the original initialization order.
        self._batch_fn = batch_fn
        batch_type = MapTransformFnDataType.Batch
        super().__init__(batch_type, batch_type)

    def __call__(
        self, input: Iterable[DataBatch], ctx: TaskContext
    ) -> Iterable[DataBatch]:
        """Yield every batch produced by the wrapped callable."""
        produced = self._batch_fn(input, ctx)
        yield from produced

    def __repr__(self) -> str:
        """Return a debug representation that includes the wrapped callable."""
        batch_fn = self._batch_fn
        return f"BatchMapTransformFn({batch_fn})"


class BlockMapTransformFn(MapTransformFn):
    """A MapTransformFn whose input and output are both blocks.

    It wraps a user-supplied callable and declares
    `MapTransformFnDataType.Block` as both its input and output type.
    """

    def __init__(self, block_fn: MapTransformCallable[Block, Block]):
        """Store the wrapped callable and register the Block -> Block types.

        Args:
            block_fn: The callable invoked with the input blocks and the
                task context each time this transform fn is called.
        """
        # Keep the callable on the instance before delegating to the base
        # class, mirroring the original initialization order.
        self._block_fn = block_fn
        block_type = MapTransformFnDataType.Block
        super().__init__(block_type, block_type)

    def __call__(self, input: Iterable[Block], ctx: TaskContext) -> Iterable[Block]:
        """Yield every block produced by the wrapped callable."""
        produced = self._block_fn(input, ctx)
        yield from produced

    def __repr__(self) -> str:
        """Return a debug representation that includes the wrapped callable."""
        block_fn = self._block_fn
        return f"BlockMapTransformFn({block_fn})"


class BlocksToRowsMapTransformFn(MapTransformFn):
"""A MapTransformFn that converts input blocks to rows."""

def __init__(self):
super().__init__(
self._input_blocks_to_rows,
MapTransformFnDataType.Block,
MapTransformFnDataType.Row,
)

def _input_blocks_to_rows(
self, blocks: Iterable[Block], _: TaskContext
) -> Iterable[Row]:
def __call__(self, blocks: Iterable[Block], _: TaskContext) -> Iterable[Row]:
for block in blocks:
block = BlockAccessor.for_block(block)
for row in block.iter_rows(public_row_format=True):
Expand All @@ -178,6 +232,9 @@ def instance(cls) -> "BlocksToRowsMapTransformFn":
cls._instance = cls()
return cls._instance

def __repr__(self) -> str:
return "BlocksToRowsMapTransformFn()"


class BlocksToBatchesMapTransformFn(MapTransformFn):
"""A MapTransformFn that converts input blocks to batches."""
Expand All @@ -192,12 +249,11 @@ def __init__(
self._batch_format = batch_format
self._ensure_copy = not zero_copy_batch and batch_size is not None
super().__init__(
self._input_blocks_to_batches,
MapTransformFnDataType.Block,
MapTransformFnDataType.Batch,
)

def _input_blocks_to_batches(
def __call__(
self,
blocks: Iterable[Block],
_: TaskContext,
Expand Down Expand Up @@ -241,6 +297,15 @@ def batch_format(self) -> str:
def zero_copy_batch(self) -> bool:
return not self._ensure_copy

def __repr__(self) -> str:
return (
f"BlocksToBatchesMapTransformFn("
f"batch_size={self._batch_size}, "
f"batch_format={self._batch_format}, "
f"zero_copy_batch={self.zero_copy_batch}"
f")"
)


class BuildOutputBlocksMapTransformFn(MapTransformFn):
"""A MapTransformFn that converts UDF-returned data to output blocks."""
Expand All @@ -252,12 +317,11 @@ def __init__(self, input_type: MapTransformFnDataType):
"""
self._input_type = input_type
super().__init__(
self._to_output_blocks,
input_type,
MapTransformFnDataType.Block,
)

def _to_output_blocks(
def __call__(
self,
iter: Iterable[MapTransformFnData],
_: TaskContext,
Expand Down Expand Up @@ -306,3 +370,6 @@ def for_blocks(cls) -> "BuildOutputBlocksMapTransformFn":
if getattr(cls, "_instance_for_blocks", None) is None:
cls._instance_for_blocks = cls(MapTransformFnDataType.Block)
return cls._instance_for_blocks

def __repr__(self) -> str:
return f"BuildOutputBlocksMapTransformFn(input_type={self._input_type})"
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from ray.data._internal.logical.rules.operator_fusion import OperatorFusionRule
from ray.data._internal.logical.rules.randomize_blocks import ReorderRandomizeBlocksRule
from ray.data._internal.logical.rules.zero_copy_map_fusion import (
EliminateBuildOutputBlocks,
)


def get_logical_optimizer_rules():
Expand All @@ -8,5 +11,7 @@ def get_logical_optimizer_rules():


def get_physical_optimizer_rules():
rules = [OperatorFusionRule]
# Subclasses of ZeroCopyMapFusionRule (e.g., EliminateBuildOutputBlocks) should
# be run after OperatorFusionRule.
rules = [OperatorFusionRule, EliminateBuildOutputBlocks]
return rules
88 changes: 88 additions & 0 deletions python/ray/data/_internal/logical/rules/zero_copy_map_fusion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
from abc import abstractmethod
from typing import List

from ray.data._internal.execution.operators.map_operator import MapOperator
from ray.data._internal.execution.operators.map_transformer import (
BuildOutputBlocksMapTransformFn,
MapTransformFn,
MapTransformFnDataType,
)
from ray.data._internal.logical.interfaces.optimizer import Rule
from ray.data._internal.logical.interfaces.physical_plan import PhysicalPlan


class ZeroCopyMapFusionRule(Rule):
"""Base abstract class for all zero-copy map fusion rules.

A zero-copy map fusion rule is a rule that optimizes the transform_fn chain of
a fused MapOperator. The optimization is usually done by removing unnecessary
data conversions.

This abstract base class defines the common utility functions; subclasses
should implement the `_optimize` method with the concrete optimization
strategy.
"""

def apply(self, plan: PhysicalPlan) -> PhysicalPlan:
self._traverse(plan.dag)
return plan

def _traverse(self, op):
raulchen marked this conversation as resolved.
Show resolved Hide resolved
"""Traverse the DAG and apply the optimization to each MapOperator."""
if isinstance(op, MapOperator):
map_transformer = op.get_map_transformer()
transform_fns = map_transformer.get_transform_fns()
new_transform_fns = self._optimize(transform_fns)
# Physical operators won't be shared,
# so it's safe to modify the transform_fns in place.
map_transformer.set_transform_fns(new_transform_fns)

for input_op in op.input_dependencies:
self._traverse(input_op)

@abstractmethod
def _optimize(self, transform_fns: List[MapTransformFn]) -> List[MapTransformFn]:
"""Optimize the transform_fns chain of a MapOperator.

Args:
transform_fns: The old transform_fns chain.
Returns:
The optimized transform_fns chain.
"""
...


class EliminateBuildOutputBlocks(ZeroCopyMapFusionRule):
"""This rule eliminates unnecessary BuildOutputBlocksMapTransformFn,
if the previous fn already outputs blocks.

This happens for the "Read -> Map/Write" fusion.
"""

def _optimize(self, transform_fns: List[MapTransformFn]) -> List[MapTransformFn]:
# For the following subsequence,
# 1. Any MapTransformFn with block output.
# 2. BuildOutputBlocksMapTransformFn
# 3. Any MapTransformFn with block input.
# We drop the BuildOutputBlocksMapTransformFn in the middle.
new_transform_fns = []

for i in range(len(transform_fns)):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
for i in range(len(transform_fns)):
for i in range(1, len(transform_fns) - 1):

Nit

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is intentional, because the first and last transform_fns also need to be added to the result.

cur_fn = transform_fns[i]
drop = False
if (
i > 0
and i < len(transform_fns) - 1
and isinstance(cur_fn, BuildOutputBlocksMapTransformFn)
):
prev_fn = transform_fns[i - 1]
next_fn = transform_fns[i + 1]
if (
prev_fn.output_type == MapTransformFnDataType.Block
and next_fn.input_type == MapTransformFnDataType.Block
):
drop = True
if not drop:
new_transform_fns.append(cur_fn)

return new_transform_fns
23 changes: 9 additions & 14 deletions python/ray/data/_internal/planner/plan_read_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,9 @@
from ray.data._internal.execution.operators.input_data_buffer import InputDataBuffer
from ray.data._internal.execution.operators.map_operator import MapOperator
from ray.data._internal.execution.operators.map_transformer import (
BlockMapTransformFn,
BuildOutputBlocksMapTransformFn,
MapTransformer,
MapTransformFn,
MapTransformFnDataType,
)
from ray.data._internal.logical.operators.read_operator import Read
from ray.data.block import Block, BlockAccessor
Expand Down Expand Up @@ -110,23 +109,19 @@ def do_read(blocks: Iterable[ReadTask], _: TaskContext) -> Iterable[Block]:
# Create a MapTransformer for a read operator
transform_fns = [
# First, execute the read tasks.
MapTransformFn(
do_read, MapTransformFnDataType.Block, MapTransformFnDataType.Block
),
BlockMapTransformFn(do_read),
# Then build the output blocks.
BuildOutputBlocksMapTransformFn.for_blocks(),
]

if op._additional_split_factor is not None:
# If an additional split is needed, do it last.
transform_fns.append(
MapTransformFn(
BlockMapTransformFn(
functools.partial(
_do_additional_splits,
additional_output_splits=op._additional_split_factor,
),
MapTransformFnDataType.Block,
MapTransformFnDataType.Block,
)
),
)

Expand All @@ -148,17 +143,17 @@ def apply_output_blocks_handling_to_read_task(

This function is only used for compatibility with the legacy LazyBlockList code path.
"""
transform_fns: List[MapTransformFn] = [BuildOutputBlocksMapTransformFn.for_blocks()]
transform_fns: List[BlockMapTransformFn] = [
BuildOutputBlocksMapTransformFn.for_blocks()
]

if additional_split_factor is not None:
transform_fns.append(
MapTransformFn(
BlockMapTransformFn(
functools.partial(
_do_additional_splits,
additional_output_splits=additional_split_factor,
),
MapTransformFnDataType.Block,
MapTransformFnDataType.Block,
)
),
)
map_transformer = MapTransformer(transform_fns)
Expand Down
Loading
Loading