Skip to content

Commit

Permalink
[data] add support for multiple group keys in map_groups
Browse files Browse the repository at this point in the history
Signed-off-by: Kit Lee <wklee4993@gmail.com>
  • Loading branch information
wingkitlee0 committed Nov 1, 2023
1 parent df6fe4c commit 71bbb60
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 5 deletions.
19 changes: 14 additions & 5 deletions python/ray/data/grouped_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,16 +293,25 @@ def map_groups(
else:
sorted_ds = self._dataset.repartition(1)

def get_key_boundaries(block_accessor: BlockAccessor) -> List[int]:
    """Compute the end offsets of each group within a sorted block.

    The block is assumed to be sorted by the grouping key(s), so all
    rows of a group are contiguous. Returns a list of exclusive end
    indices, one per group, in row order.
    """
    import numpy as np

    # Keys of the block in numpy format: an ndarray for a single key,
    # or a dict of column-name -> ndarray for multiple keys.
    keys = block_accessor.to_numpy(self._key)

    if isinstance(keys, np.ndarray):
        arr = keys
    else:
        # Combine the key columns into one object array of per-row
        # tuples so that rows can be compared for equality.
        first_key = next(iter(keys))
        arr = np.empty(len(keys[first_key]), dtype=object)
        arr[:] = list(zip(*keys.values()))

    # NOTE: np.searchsorted is not safe here for the multi-key case:
    # binary search requires the array to be sorted, but stringified /
    # tupled composite keys are not guaranteed to be ordered even when
    # the underlying rows are (e.g. "(10, 0)" < "(2, 0)" as strings).
    # Equal keys are contiguous, so a linear equality scan is both
    # correct and O(n).
    boundaries: List[int] = []
    start = 0
    n = len(arr)
    while start < n:
        end = start + 1
        while end < n and arr[end] == arr[start]:
            end += 1
        boundaries.append(end)
        start = end
    return boundaries
Expand Down
30 changes: 30 additions & 0 deletions python/ray/data/tests/test_all_to_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,36 @@ def test_groupby_multiple_keys_tabular_count(
]


@pytest.mark.parametrize("num_parts", [1, 30])
@pytest.mark.parametrize("ds_format", ["pyarrow", "pandas", "numpy"])
def test_groupby_multiple_keys_map_groups(
    ray_start_regular_shared, ds_format, num_parts, use_push_based_shuffle
):
    # Test map_groups when grouping by multiple keys.
    print(f"Seeding RNG for test_groupby_multiple_keys_map_groups with: {RANDOM_SEED}")
    random.seed(RANDOM_SEED)
    xs = list(range(100))
    random.shuffle(xs)

    # 100 rows keyed by (x % 2, x % 3): x mod 6 determines the (A, B)
    # group, giving 6 distinct groups.
    ds = ray.data.from_items([{"A": (x % 2), "B": (x % 3)} for x in xs]).repartition(
        num_parts
    )
    ds = ds.map_batches(lambda x: x, batch_size=None, batch_format=ds_format)

    agg_ds = ds.groupby(["A", "B"]).map_groups(
        lambda df: {"count": [len(df["A"])]}, batch_format=ds_format
    )
    assert agg_ds.count() == 6
    # In range(100), residues 0-3 (mod 6) occur 17 times each and
    # residues 4-5 occur 16 times; groups are ordered by (A, B).
    assert agg_ds.take_all() == [
        {"count": 17},
        {"count": 16},
        {"count": 17},
        {"count": 17},
        {"count": 17},
        {"count": 16},
    ]


@pytest.mark.parametrize("num_parts", [1, 30])
@pytest.mark.parametrize("ds_format", ["arrow", "pandas"])
def test_groupby_tabular_sum(
Expand Down

0 comments on commit 71bbb60

Please sign in to comment.