Skip to content

Commit

Permalink
fix(rust, python): fix boolean schema in agg_max/min (#5678)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Nov 30, 2022
1 parent ef53b62 commit 43ed5c9
Show file tree
Hide file tree
Showing 4 changed files with 150 additions and 24 deletions.
15 changes: 15 additions & 0 deletions polars/polars-core/src/chunked_array/upstream_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,21 @@ impl FromParallelIterator<bool> for BooleanChunked {
}
}

impl FromParallelIterator<Option<bool>> for BooleanChunked {
fn from_par_iter<I: IntoParallelIterator<Item = Option<bool>>>(iter: I) -> Self {
let vectors = collect_into_linked_list(iter);

let capacity: usize = get_capacity_from_par_results(&vectors);

let arr = unsafe {
BooleanArray::from_trusted_len_iter(
vectors.into_iter().flatten().trust_my_length(capacity),
)
};
Self::from_chunks("", vec![Box::new(arr)])
}
}

impl<Ptr> FromParallelIterator<Ptr> for Utf8Chunked
where
Ptr: PolarsAsRef<str> + Send + Sync,
Expand Down
97 changes: 95 additions & 2 deletions polars/polars-core/src/frame/groupby/aggregations/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,14 @@ where
ca.into_series()
}

pub fn _agg_helper_idx_bool<F>(groups: &GroupsIdx, f: F) -> Series
where
F: Fn((IdxSize, &Vec<IdxSize>)) -> Option<bool> + Send + Sync,
{
let ca: BooleanChunked = POOL.install(|| groups.into_par_iter().map(f).collect());
ca.into_series()
}

// helper that iterates on the `all: Vec<Vec<u32>` collection
// this doesn't have traverse the `first: Vec<u32>` memory and is therefore faster
fn agg_helper_idx_on_all<T, F>(groups: &GroupsIdx, f: F) -> Series
Expand All @@ -183,12 +191,97 @@ where
ca.into_series()
}

pub fn _agg_helper_slice_bool<F>(groups: &[[IdxSize; 2]], f: F) -> Series
where
F: Fn([IdxSize; 2]) -> Option<bool> + Send + Sync,
{
let ca: BooleanChunked = POOL.install(|| groups.par_iter().copied().map(f).collect());
ca.into_series()
}

impl BooleanChunked {
pub(crate) unsafe fn agg_min(&self, groups: &GroupsProxy) -> Series {
self.cast(&IDX_DTYPE).unwrap().agg_min(groups)
// faster paths
match (self.is_sorted2(), self.null_count()) {
(IsSorted::Ascending, 0) => {
return self.clone().into_series().agg_first(groups);
}
(IsSorted::Descending, 0) => {
return self.clone().into_series().agg_last(groups);
}
_ => {}
}
match groups {
GroupsProxy::Idx(groups) => _agg_helper_idx_bool(groups, |(first, idx)| {
debug_assert!(idx.len() <= self.len());
if idx.is_empty() {
None
} else if idx.len() == 1 {
self.get(first as usize)
} else {
// TODO! optimize this
// can just check if any is false and early stop
let take = { self.take_unchecked(idx.into()) };
take.min().map(|v| v == 1)
}
}),
GroupsProxy::Slice {
groups: groups_slice,
..
} => _agg_helper_slice_bool(groups_slice, |[first, len]| {
debug_assert!(len <= self.len() as IdxSize);
match len {
0 => None,
1 => self.get(first as usize),
_ => {
let arr_group = _slice_from_offsets(self, first, len);
arr_group.min().map(|v| v == 1)
}
}
}),
}
}
pub(crate) unsafe fn agg_max(&self, groups: &GroupsProxy) -> Series {
self.cast(&IDX_DTYPE).unwrap().agg_max(groups)
// faster paths
match (self.is_sorted2(), self.null_count()) {
(IsSorted::Ascending, 0) => {
return self.clone().into_series().agg_last(groups);
}
(IsSorted::Descending, 0) => {
return self.clone().into_series().agg_first(groups);
}
_ => {}
}

match groups {
GroupsProxy::Idx(groups) => _agg_helper_idx_bool(groups, |(first, idx)| {
debug_assert!(idx.len() <= self.len());
if idx.is_empty() {
None
} else if idx.len() == 1 {
self.get(first as usize)
} else {
// TODO! optimize this
// can just check if any is true and early stop
let take = { self.take_unchecked(idx.into()) };
take.max().map(|v| v == 1)
}
}),
GroupsProxy::Slice {
groups: groups_slice,
..
} => _agg_helper_slice_bool(groups_slice, |[first, len]| {
debug_assert!(len <= self.len() as IdxSize);
match len {
0 => None,
1 => self.get(first as usize),
_ => {
let arr_group = _slice_from_offsets(self, first, len);
arr_group.max().map(|v| v == 1)
}
}
}),
}
}
pub(crate) unsafe fn agg_sum(&self, groups: &GroupsProxy) -> Series {
self.cast(&IDX_DTYPE).unwrap().agg_sum(groups)
Expand Down
44 changes: 22 additions & 22 deletions py-polars/polars/internals/dataframe/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,17 +600,17 @@ def min(self) -> pli.DataFrame:
... )
>>> df.groupby("d", maintain_order=True).min()
shape: (3, 4)
┌────────┬─────┬──────┬─────┐
│ d ┆ a ┆ b ┆ c │
│ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ i64 ┆ f64 ┆ u32
╞════════╪═════╪══════╪═════╡
│ Apple ┆ 1 ┆ 0.5 ┆ 0
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌┤
│ Orange ┆ 2 ┆ 0.5 ┆ 1
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌┤
│ Banana ┆ 4 ┆ 13.0 ┆ 0
└────────┴─────┴──────┴─────┘
┌────────┬─────┬──────┬───────
│ d ┆ a ┆ b ┆ c
│ --- ┆ --- ┆ --- ┆ ---
│ str ┆ i64 ┆ f64 ┆ bool
╞════════╪═════╪══════╪═══════
│ Apple ┆ 1 ┆ 0.5 ┆ false
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌╌╌
│ Orange ┆ 2 ┆ 0.5 ┆ true
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌╌╌
│ Banana ┆ 4 ┆ 13.0 ┆ false
└────────┴─────┴──────┴───────
"""
return self.agg(pli.all().min())
Expand All @@ -631,17 +631,17 @@ def max(self) -> pli.DataFrame:
... )
>>> df.groupby("d", maintain_order=True).max()
shape: (3, 4)
┌────────┬─────┬──────┬─────┐
│ d ┆ a ┆ b ┆ c │
│ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ i64 ┆ f64 ┆ u32
╞════════╪═════╪══════╪═════╡
│ Apple ┆ 3 ┆ 10.0 ┆ 1
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌┤
│ Orange ┆ 2 ┆ 0.5 ┆ 1
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌┤
│ Banana ┆ 5 ┆ 14.0 ┆ 1
└────────┴─────┴──────┴─────┘
┌────────┬─────┬──────┬─────
│ d ┆ a ┆ b ┆ c
│ --- ┆ --- ┆ --- ┆ ---
│ str ┆ i64 ┆ f64 ┆ bool
╞════════╪═════╪══════╪═════
│ Apple ┆ 3 ┆ 10.0 ┆ true
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌
│ Orange ┆ 2 ┆ 0.5 ┆ true
├╌╌╌╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌
│ Banana ┆ 5 ┆ 14.0 ┆ true
└────────┴─────┴──────┴─────
"""
return self.agg(pli.all().max())
Expand Down
18 changes: 18 additions & 0 deletions py-polars/tests/unit/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,3 +249,21 @@ def test_diff_duration_dtype() -> None:
False,
True,
]


def test_boolean_agg_schema() -> None:
df = pl.DataFrame(
{
"x": [1, 1, 1],
"y": [False, True, False],
}
).lazy()

agg_df = df.groupby("x").agg(pl.col("y").max().alias("max_y"))

for streaming in [True, False]:
assert (
agg_df.collect(streaming=streaming).schema
== agg_df.schema
== {"x": pl.Int64, "max_y": pl.Boolean}
)

0 comments on commit 43ed5c9

Please sign in to comment.