Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[data] add AllToAllAPI to dataset methods #40842

Merged
merged 10 commits into from
Nov 2, 2023
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions python/ray/data/_internal/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,32 @@ def ConsumptionAPI(*args, **kwargs):
return _consumption_api(*args, **kwargs)


def _all_to_all_api(*args, **kwargs):
    """Build a decorator that tags a callable's docs as an all-to-all API.

    The positional and keyword arguments are accepted but are not used by
    this function.

    Returns:
        A one-argument decorator. It passes the decorated object to
        ``_insert_doc_at_pattern`` along with a ``note`` directive message
        about object-store materialization, then returns that same object.
    """
    # Adjacent literals concatenate to the exact single-line message.
    note = (
        "This operation requires all inputs to be materialized in "
        "object store for it to execute."
    )

    def decorate(obj):
        # Presumably this places the note ahead of the "Examples:" section
        # (insert_after=False) -- confirm against _insert_doc_at_pattern.
        _insert_doc_at_pattern(
            obj,
            message=note,
            pattern="Examples:",
            insert_after=False,
            directive="note",
        )
        return obj

    return decorate


def AllToAllAPI(*args, **kwargs):
    """Mark a function as an all-to-all API.

    Supports both the bare form (``@AllToAllAPI``) and the called form
    (``@AllToAllAPI(...)``).

    Returns:
        In the bare form (exactly one positional argument that is callable
        and no keyword arguments), the result of applying the decorator
        built by ``_all_to_all_api()`` to that callable. Otherwise, the
        decorator returned by ``_all_to_all_api(*args, **kwargs)``.
    """
    # Short-circuiting keeps args[0] from being touched unless it exists.
    used_bare = len(args) == 1 and not kwargs and callable(args[0])
    if not used_bare:
        return _all_to_all_api(*args, **kwargs)
    return _all_to_all_api()(args[0])


def _split_list(arr: List[Any], num_splits: int) -> List[List[Any]]:
"""Split the list into `num_splits` lists.

Expand Down
14 changes: 13 additions & 1 deletion python/ray/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@
DatasetStatsSummary,
get_dataset_id_from_stats_actor,
)
from ray.data._internal.util import ConsumptionAPI, _is_local_scheme, validate_compute
from ray.data._internal.util import AllToAllAPI, ConsumptionAPI, _is_local_scheme, validate_compute
from ray.data.aggregate import AggregateFn, Max, Mean, Min, Std, Sum
from ray.data.block import (
VALID_BATCH_FORMATS,
Expand Down Expand Up @@ -974,6 +974,7 @@ def filter(

return Dataset(plan, logical_plan)

@AllToAllAPI
def repartition(self, num_blocks: int, *, shuffle: bool = False) -> "Dataset":
"""Repartition the :class:`Dataset` into exactly this number of :ref:`blocks <dataset_concept>`.

Expand Down Expand Up @@ -1029,6 +1030,7 @@ def repartition(self, num_blocks: int, *, shuffle: bool = False) -> "Dataset":
logical_plan = LogicalPlan(op)
return Dataset(plan, logical_plan)

@AllToAllAPI
def random_shuffle(
self,
*,
Expand Down Expand Up @@ -1079,6 +1081,7 @@ def random_shuffle(
logical_plan = LogicalPlan(op)
return Dataset(plan, logical_plan)

@AllToAllAPI
def randomize_block_order(
self,
*,
Expand Down Expand Up @@ -1833,6 +1836,7 @@ def union(self, *other: List["Dataset"]) -> "Dataset":
logical_plan,
)

@AllToAllAPI
def groupby(self, key: Union[str, List[str], None]) -> "GroupedData":
"""Group rows of a :class:`Dataset` according to a column.

Expand Down Expand Up @@ -1880,6 +1884,7 @@ def normalize_variety(group: pd.DataFrame) -> pd.DataFrame:

return GroupedData(self, key)

@AllToAllAPI
def unique(self, column: str) -> List[Any]:
"""List the unique elements in a given column.

Expand Down Expand Up @@ -1921,6 +1926,7 @@ def unique(self, column: str) -> List[Any]:
ds = self.select_columns([column]).groupby(column).count()
return [item[column] for item in ds.take_all()]

@AllToAllAPI
@ConsumptionAPI
def aggregate(self, *aggs: AggregateFn) -> Union[Any, Dict[str, Any]]:
"""Aggregate values using one or more functions.
Expand Down Expand Up @@ -1958,6 +1964,7 @@ def aggregate(self, *aggs: AggregateFn) -> Union[Any, Dict[str, Any]]:
ret = self.groupby(None).aggregate(*aggs).take(1)
return ret[0] if len(ret) > 0 else None

@AllToAllAPI
@ConsumptionAPI
def sum(
self, on: Optional[Union[str, List[str]]] = None, ignore_nulls: bool = True
Expand Down Expand Up @@ -2000,6 +2007,7 @@ def sum(
ret = self._aggregate_on(Sum, on, ignore_nulls)
return self._aggregate_result(ret)

@AllToAllAPI
@ConsumptionAPI
def min(
self, on: Optional[Union[str, List[str]]] = None, ignore_nulls: bool = True
Expand Down Expand Up @@ -2042,6 +2050,7 @@ def min(
ret = self._aggregate_on(Min, on, ignore_nulls)
return self._aggregate_result(ret)

@AllToAllAPI
@ConsumptionAPI
def max(
self, on: Optional[Union[str, List[str]]] = None, ignore_nulls: bool = True
Expand Down Expand Up @@ -2084,6 +2093,7 @@ def max(
ret = self._aggregate_on(Max, on, ignore_nulls)
return self._aggregate_result(ret)

@AllToAllAPI
@ConsumptionAPI
def mean(
self, on: Optional[Union[str, List[str]]] = None, ignore_nulls: bool = True
Expand Down Expand Up @@ -2126,6 +2136,7 @@ def mean(
ret = self._aggregate_on(Mean, on, ignore_nulls)
return self._aggregate_result(ret)

@AllToAllAPI
@ConsumptionAPI
def std(
self,
Expand Down Expand Up @@ -2182,6 +2193,7 @@ def std(
ret = self._aggregate_on(Std, on, ignore_nulls, ddof=ddof)
return self._aggregate_result(ret)

@AllToAllAPI
def sort(
self,
key: Union[str, List[str], None] = None,
Expand Down
Loading