fix(rust, python): group_by partitioned with literal Series panic (
CanglongCl committed Apr 7, 2024
1 parent cc6c642 commit eda3ccd
Showing 2 changed files with 34 additions and 10 deletions.
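
In short: the panic occurred when group_by was keyed on a literal Series (a key not derived from the frame's own columns) and the query took the partitioned path, because the key expression was re-evaluated against every split chunk, so the full-length Series no longer matched the chunk length. The snippet below is a user-level illustration of the affected pattern, not part of the commit; the row count is an assumed value, chosen only to be large enough for the partitioned strategy to be considered.

import polars as pl

# Assumed size: the partitioned group_by is only chosen above an internal
# row-count threshold, so N is simply "big enough" for illustration.
N = 100_000
df = pl.DataFrame({"x": list(range(N))})
key = pl.Series("g", [i % 3 for i in range(N)])  # literal Series key, full frame length

# Before this fix the partitioned path could panic on such a key; with the fix
# the keys are computed once on the full frame and then split alongside it.
out = df.group_by(key).agg(pl.col("x").sum())
print(out.sort("g"))
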
@@ -69,19 +69,23 @@ fn run_partitions(
     state: &ExecutionState,
     n_threads: usize,
     maintain_order: bool,
-) -> PolarsResult<Vec<DataFrame>> {
+) -> PolarsResult<(Vec<DataFrame>, Vec<Vec<Series>>)> {
     // We do a partitioned group_by.
     // Meaning that we first do the group_by operation arbitrarily
     // split on several threads. Then we apply the same group_by again on the combined result.
     let dfs = split_df(df, n_threads)?;
 
     let phys_aggs = &exec.phys_aggs;
     let keys = &exec.phys_keys;
+
+    let mut keys = DataFrame::from_iter(compute_keys(keys, df, state)?);
+    let splitted_keys = split_df(&mut keys, n_threads)?;
+
     POOL.install(|| {
         dfs.into_par_iter()
-            .map(|df| {
-                let keys = compute_keys(keys, &df, state)?;
-                let gb = df.group_by_with_series(keys, false, maintain_order)?;
+            .zip(splitted_keys)
+            .map(|(df, keys)| {
+                let gb = df.group_by_with_series(keys.into(), false, maintain_order)?;
                 let groups = gb.get_groups();
 
                 let mut columns = gb.keys();
@@ -106,7 +110,8 @@ fn run_partitions(
 
                 columns.extend_from_slice(&agg_columns);
 
-                DataFrame::new(columns)
+                let df = DataFrame::new(columns)?;
+                Ok((df, gb.keys()))
             })
             .collect()
     })
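
The comment at the top of run_partitions describes the overall strategy: aggregate each chunk independently on its own thread, then run the same group_by once more over the combined partial results. A rough user-level Python sketch of that two-phase idea (not the internal implementation; it works here only because sum can be re-aggregated):

import polars as pl

df = pl.DataFrame({"g": [0, 1, 0, 1, 2, 2, 0, 1], "v": [1, 2, 3, 4, 5, 6, 7, 8]})

# Phase 1: split the frame and compute partial sums per chunk.
n_chunks = 2
chunk_len = df.height // n_chunks
chunks = [df.slice(i * chunk_len, chunk_len) for i in range(n_chunks)]
partials = [c.group_by("g").agg(pl.col("v").sum()) for c in chunks]

# Phase 2: concatenate the partials and apply the same group_by again.
merged = pl.concat(partials).group_by("g").agg(pl.col("v").sum())

# Matches a single global group_by, since partial sums can be summed again.
expected = df.group_by("g").agg(pl.col("v").sum())
assert merged.sort("g").equals(expected.sort("g"))
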
@@ -297,7 +302,7 @@ impl PartitionGroupByExec {
         state: &mut ExecutionState,
         mut original_df: DataFrame,
     ) -> PolarsResult<DataFrame> {
-        let dfs = {
+        let (splitted_dfs, splitted_keys) = {
             // already get the keys. This is the very last minute decision which group_by method we choose.
             // If the column is a categorical, we know the number of groups we have and can decide to continue
             // partitioned or go for the standard group_by. The partitioned is likely to be faster on a small number
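
The truncated comment above refers to the cardinality check: with a Categorical key the number of possible groups is known from the dictionary, so the executor can cheaply decide between the partitioned and the standard group_by. A small illustration of that observation (not the executor's actual heuristic):

import polars as pl

# With a Categorical key, an upper bound on the number of groups is available
# from the category dictionary alone, without scanning the data.
s = pl.Series("k", ["a", "b", "a", "c"], dtype=pl.Categorical)
n_groups_upper_bound = s.cat.get_categories().len()
print(n_groups_upper_bound)  # 3
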
@@ -339,12 +344,23 @@ impl PartitionGroupByExec {
             )?
         };
 
-        state.set_schema(self.output_schema.clone());
         // MERGE phase
         // merge and hash aggregate again
-        let df = accumulate_dataframes_vertical(dfs)?;
-
+        let df = accumulate_dataframes_vertical(splitted_dfs)?;
+        let keys = splitted_keys
+            .into_iter()
+            .reduce(|mut acc, e| {
+                acc.iter_mut().zip(e).for_each(|(acc, e)| {
+                    let _ = acc.append(&e);
+                });
+                acc
+            })
+            .unwrap();
+
         // the partitioned group_by has added columns so we must update the schema.
-        let keys = self.keys(&df, state)?;
+        state.set_schema(self.output_schema.clone());
+
+        // merge and hash aggregate again
+
         // first get mutable access and optionally sort
         let gb = df.group_by_with_series(keys, true, self.maintain_order)?;
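
The reduce/append loop above stitches the key columns produced by the individual partitions back into full-length Series for the final group_by, instead of recomputing the keys from the merged frame as before, which is exactly what broke for literal Series keys. A loose Python analogue of that merge step (names and shapes are illustrative only):

import polars as pl

# Each inner list holds the key Series one partition produced (one Series per key).
partition_keys = [
    [pl.Series("g", [0, 1])],  # keys from partition 0
    [pl.Series("g", [1, 2])],  # keys from partition 1
]

# Concatenate per key column across partitions, mirroring the acc.append(&e) loop.
merged_keys = [
    pl.concat([keys[i] for keys in partition_keys])
    for i in range(len(partition_keys[0]))
]
print(merged_keys[0].to_list())  # [0, 1, 1, 2]
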
8 changes: 8 additions & 0 deletions py-polars/tests/unit/operations/test_group_by.py
@@ -9,6 +9,7 @@
 import polars as pl
 import polars.selectors as cs
 from polars.testing import assert_frame_equal, assert_series_equal
+from polars.testing._constants import PARTITION_LIMIT
 
 if TYPE_CHECKING:
     from polars.type_aliases import PolarsDataType
@@ -768,6 +769,13 @@ def test_group_by_partitioned_ending_cast(monkeypatch: Any) -> None:
     assert_frame_equal(out, expected)
 
 
+def test_group_by_series_partitioned() -> None:
+    # test 15354
+    df = pl.DataFrame([0, 0] * PARTITION_LIMIT)
+    groups = pl.Series([0, 1] * PARTITION_LIMIT)
+    df.group_by(groups).agg(pl.all().is_not_null().sum())
+
+
 def test_groupby_deprecated() -> None:
     df = pl.DataFrame({"a": [1, 1, 2], "b": [3, 4, 5]})
 