Skip to content

Commit

Permalink
[data] add AllToAllAPI to dataset methods (#40842)
Browse files Browse the repository at this point in the history
Adds an `AllToAllAPI` decorator for `Dataset` methods that create all-to-all operators.

The list:

- repartition
- random_shuffle
- randomize_block_order
- groupby
- unique
- aggregate
- sum
- min
- max
- mean
- std
- sort

Please let me know if I missed any

Signed-off-by: Andrew Xue <andewzxue@gmail.com>
Co-authored-by: Cheng Su <scnju13@gmail.com>
  • Loading branch information
Zandew and c21 committed Nov 2, 2023
1 parent 0588b17 commit 45ac625
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 1 deletion.
30 changes: 30 additions & 0 deletions python/ray/data/_internal/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,36 @@ def ConsumptionAPI(*args, **kwargs):
return _consumption_api(*args, **kwargs)


def _all_to_all_api(*args, **kwargs):
"""Annotate the function with an indication that it's a all to all API, and that it
is an operation that requires all inputs to be materialized in-memory to execute.
"""

def wrap(obj):
_insert_doc_at_pattern(
obj,
message=(
"This operation requires all inputs to be "
"materialized in object store for it to execute."
),
pattern="Examples:",
insert_after=False,
directive="note",
)
return obj

return wrap


def AllToAllAPI(*args, **kwargs):
    """Annotate a function as an all-to-all API.

    Intended to be applied as a bare decorator (``@AllToAllAPI``) on dataset
    methods. The decorated callable is passed through the decorator built by
    ``_all_to_all_api()``, which attaches a note that the operation requires
    all inputs to be materialized in the object store.

    Args:
        *args: Must be exactly one positional argument, the callable being
            decorated.
        **kwargs: Must be empty; configuration arguments are not supported.

    Returns:
        Whatever the decorator produced by ``_all_to_all_api()`` returns for
        the given callable.

    Raises:
        AssertionError: If not used as a bare decorator on a single callable.
    """
    # Raise explicitly instead of using `assert`: assert statements are
    # stripped under `python -O`, which would let misuse fall through to an
    # opaque IndexError on `args[0]` or silently accept a non-callable.
    used_as_bare_decorator = len(args) == 1 and not kwargs and callable(args[0])
    if not used_as_bare_decorator:
        raise AssertionError(
            "AllToAllAPI must be applied as a bare decorator (`@AllToAllAPI`) "
            "to a single callable."
        )
    return _all_to_all_api()(args[0])


def _split_list(arr: List[Any], num_splits: int) -> List[List[Any]]:
"""Split the list into `num_splits` lists.
Expand Down
19 changes: 18 additions & 1 deletion python/ray/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,12 @@
DatasetStatsSummary,
get_dataset_id_from_stats_actor,
)
from ray.data._internal.util import ConsumptionAPI, _is_local_scheme, validate_compute
from ray.data._internal.util import (
AllToAllAPI,
ConsumptionAPI,
_is_local_scheme,
validate_compute,
)
from ray.data.aggregate import AggregateFn, Max, Mean, Min, Std, Sum
from ray.data.block import (
VALID_BATCH_FORMATS,
Expand Down Expand Up @@ -974,6 +979,7 @@ def filter(

return Dataset(plan, logical_plan)

@AllToAllAPI
def repartition(self, num_blocks: int, *, shuffle: bool = False) -> "Dataset":
"""Repartition the :class:`Dataset` into exactly this number of :ref:`blocks <dataset_concept>`.
Expand Down Expand Up @@ -1029,6 +1035,7 @@ def repartition(self, num_blocks: int, *, shuffle: bool = False) -> "Dataset":
logical_plan = LogicalPlan(op)
return Dataset(plan, logical_plan)

@AllToAllAPI
def random_shuffle(
self,
*,
Expand Down Expand Up @@ -1079,6 +1086,7 @@ def random_shuffle(
logical_plan = LogicalPlan(op)
return Dataset(plan, logical_plan)

@AllToAllAPI
def randomize_block_order(
self,
*,
Expand Down Expand Up @@ -1833,6 +1841,7 @@ def union(self, *other: List["Dataset"]) -> "Dataset":
logical_plan,
)

@AllToAllAPI
def groupby(self, key: Union[str, List[str], None]) -> "GroupedData":
"""Group rows of a :class:`Dataset` according to a column.
Expand Down Expand Up @@ -1880,6 +1889,7 @@ def normalize_variety(group: pd.DataFrame) -> pd.DataFrame:

return GroupedData(self, key)

@AllToAllAPI
def unique(self, column: str) -> List[Any]:
"""List the unique elements in a given column.
Expand Down Expand Up @@ -1921,6 +1931,7 @@ def unique(self, column: str) -> List[Any]:
ds = self.select_columns([column]).groupby(column).count()
return [item[column] for item in ds.take_all()]

@AllToAllAPI
@ConsumptionAPI
def aggregate(self, *aggs: AggregateFn) -> Union[Any, Dict[str, Any]]:
"""Aggregate values using one or more functions.
Expand Down Expand Up @@ -1958,6 +1969,7 @@ def aggregate(self, *aggs: AggregateFn) -> Union[Any, Dict[str, Any]]:
ret = self.groupby(None).aggregate(*aggs).take(1)
return ret[0] if len(ret) > 0 else None

@AllToAllAPI
@ConsumptionAPI
def sum(
self, on: Optional[Union[str, List[str]]] = None, ignore_nulls: bool = True
Expand Down Expand Up @@ -2000,6 +2012,7 @@ def sum(
ret = self._aggregate_on(Sum, on, ignore_nulls)
return self._aggregate_result(ret)

@AllToAllAPI
@ConsumptionAPI
def min(
self, on: Optional[Union[str, List[str]]] = None, ignore_nulls: bool = True
Expand Down Expand Up @@ -2042,6 +2055,7 @@ def min(
ret = self._aggregate_on(Min, on, ignore_nulls)
return self._aggregate_result(ret)

@AllToAllAPI
@ConsumptionAPI
def max(
self, on: Optional[Union[str, List[str]]] = None, ignore_nulls: bool = True
Expand Down Expand Up @@ -2084,6 +2098,7 @@ def max(
ret = self._aggregate_on(Max, on, ignore_nulls)
return self._aggregate_result(ret)

@AllToAllAPI
@ConsumptionAPI
def mean(
self, on: Optional[Union[str, List[str]]] = None, ignore_nulls: bool = True
Expand Down Expand Up @@ -2126,6 +2141,7 @@ def mean(
ret = self._aggregate_on(Mean, on, ignore_nulls)
return self._aggregate_result(ret)

@AllToAllAPI
@ConsumptionAPI
def std(
self,
Expand Down Expand Up @@ -2182,6 +2198,7 @@ def std(
ret = self._aggregate_on(Std, on, ignore_nulls, ddof=ddof)
return self._aggregate_result(ret)

@AllToAllAPI
def sort(
self,
key: Union[str, List[str], None] = None,
Expand Down

0 comments on commit 45ac625

Please sign in to comment.