Skip to content

Commit

Permalink
fast path for sorted min/max (#4228)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Aug 2, 2022
1 parent 2dbc6bf commit 055d1d3
Show file tree
Hide file tree
Showing 7 changed files with 195 additions and 47 deletions.
44 changes: 18 additions & 26 deletions polars/polars-core/src/chunked_array/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
use crate::prelude::*;
use arrow::{array::*, bitmap::Bitmap};
use polars_arrow::prelude::ValueSize;
use std::iter::Map;
use std::marker::PhantomData;
use std::sync::Arc;

Expand Down Expand Up @@ -45,9 +46,10 @@ use polars_arrow::prelude::*;
#[cfg(feature = "dtype-categorical")]
use crate::chunked_array::categorical::RevMapping;
use crate::series::IsSorted;
use crate::utils::CustomIterTools;
use crate::utils::{first_non_null, last_non_null, CustomIterTools};
use bitflags::bitflags;
use std::mem;
use std::slice::Iter;

#[cfg(not(feature = "dtype-categorical"))]
pub struct RevMapping {}
Expand Down Expand Up @@ -207,26 +209,26 @@ impl<T: PolarsDataType> ChunkedArray<T> {

/// Get the index of the first non null value in this ChunkedArray.
///
/// Returns `None` when the array is empty; otherwise returns whatever the
/// free function `first_non_null` (imported from `crate::utils`) computes
/// from this array's per-chunk validity bitmaps.
pub fn first_non_null(&self) -> Option<usize> {
let mut offset = 0;
for validity in self.iter_validities() {
if let Some(validity) = validity {
for (idx, is_valid) in validity.iter().enumerate() {
if is_valid {
return Some(offset + idx);
}
}
offset += validity.len()
} else {
return Some(offset);
}
// Guard the empty case up front: an empty array has no non-null value,
// so the validity bitmaps are only consulted when there is data.
if self.is_empty() {
None
} else {
first_non_null(self.iter_validities())
}
None
}

/// Get the index of the last non null value in this ChunkedArray.
///
/// Hands the per-chunk validity bitmaps and the total length of the array
/// to the free function `last_non_null` and returns its answer.
pub fn last_non_null(&self) -> Option<usize> {
    let validities = self.iter_validities();
    let total_len = self.length as usize;
    last_non_null(validities, total_len)
}

/// Iterate over the chunks of this array, yielding for every chunk the
/// validity bitmap reported by its `validity()` method (`None` when the
/// chunk reports no bitmap).
#[inline]
pub fn iter_validities(&self) -> impl Iterator<Item = Option<&Bitmap>> + '_ {
self.chunks.iter().map(|arr| arr.validity())
// The concrete `Map<Iter, fn>` type is spelled out in the signature rather
// than hidden behind `impl Iterator`: `last_non_null` feeds this iterator
// to a helper that is bounded on `DoubleEndedIterator` and calls `.rev()`.
#[allow(clippy::type_complexity)]
pub fn iter_validities(&self) -> Map<Iter<'_, ArrayRef>, fn(&ArrayRef) -> Option<&Bitmap>> {
// Named mapper whose signature matches the `fn(&ArrayRef) -> Option<&Bitmap>`
// pointer type written in the return type.
fn to_validity(arr: &ArrayRef) -> Option<&Bitmap> {
arr.validity()
}
self.chunks.iter().map(to_validity)
}

#[inline]
Expand Down Expand Up @@ -763,16 +765,6 @@ pub(crate) mod test {
assert_eq!(Vec::from(&s.reverse()), &[Some("c"), None, Some("a")]);
}

// Smoke test around zero-length data: it has no assertions and only checks
// that the calls below run. `dbg!` echoes the expression text, so the
// expressions are deliberately left as written.
#[test]
fn test_null_sized_chunks() {
// Start from an empty array, then append three values to it.
let mut s = Float64Chunked::new("s", &Vec::<f64>::new());
s.append(&Float64Chunked::new("s2", &[1., 2., 3.]));
dbg!(&s);

// Pull the first item from an iterator over an empty array.
let s = Float64Chunked::new("s", &Vec::<f64>::new());
dbg!(&s.into_iter().next());
}

#[test]
#[cfg(feature = "dtype-categorical")]
fn test_iter_categorical() {
Expand Down
47 changes: 41 additions & 6 deletions polars/polars-core/src/chunked_array/ops/aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
use crate::chunked_array::builder::get_list_builder;
use crate::chunked_array::ChunkedArray;
use crate::datatypes::BooleanChunked;
use crate::series::IsSorted;
use crate::{datatypes::PolarsNumericType, prelude::*, utils::CustomIterTools};
use arrow::compute;
use arrow::types::simd::Simd;
Expand Down Expand Up @@ -68,15 +69,49 @@ where
}

/// Minimum of the array, dispatching on the sortedness reported by
/// `is_sorted2()`:
/// - `Ascending`: the value stored at the first non-null index.
/// - `Descending`: the value stored at the last non-null index.
/// - `Not`: the per-chunk minima from `compute::aggregate::min_primitive`,
///   reduced by keeping the smaller value of each pair.
fn min(&self) -> Option<T::Native> {
self.downcast_iter()
.filter_map(compute::aggregate::min_primitive)
.fold_first_(|acc, v| if acc < v { acc } else { v })
match self.is_sorted2() {
IsSorted::Ascending => {
// ascending order: the first non-null slot is read as the minimum
self.first_non_null().and_then(|idx| {
// Safety:
// first_non_null is expected to return an in-bounds index
unsafe { self.get_unchecked(idx) }
})
}
IsSorted::Descending => {
// descending order: the last non-null slot is read as the minimum
self.last_non_null().and_then(|idx| {
// Safety:
// last_non_null is expected to return an in-bounds index
unsafe { self.get_unchecked(idx) }
})
}
// no sortedness information: scan every chunk
IsSorted::Not => self
.downcast_iter()
.filter_map(compute::aggregate::min_primitive)
.fold_first_(|acc, v| if acc < v { acc } else { v }),
}
}

/// Maximum of the array, dispatching on the sortedness reported by
/// `is_sorted2()`:
/// - `Ascending`: the value stored at the last non-null index.
/// - `Descending`: the value stored at the first non-null index.
/// - `Not`: the per-chunk maxima from `compute::aggregate::max_primitive`,
///   reduced by keeping the larger value of each pair.
fn max(&self) -> Option<T::Native> {
self.downcast_iter()
.filter_map(compute::aggregate::max_primitive)
.fold_first_(|acc, v| if acc > v { acc } else { v })
match self.is_sorted2() {
IsSorted::Ascending => {
// ascending order: the last non-null slot is read as the maximum
self.last_non_null().and_then(|idx| {
// Safety:
// last_non_null is expected to return an in-bounds index
unsafe { self.get_unchecked(idx) }
})
}
IsSorted::Descending => {
// descending order: the first non-null slot is read as the maximum
self.first_non_null().and_then(|idx| {
// Safety:
// first_non_null is expected to return an in-bounds index
unsafe { self.get_unchecked(idx) }
})
}
// no sortedness information: scan every chunk
IsSorted::Not => self
.downcast_iter()
.filter_map(compute::aggregate::max_primitive)
.fold_first_(|acc, v| if acc > v { acc } else { v }),
}
}

fn mean(&self) -> Option<f64> {
Expand Down
2 changes: 2 additions & 0 deletions polars/polars-core/src/chunked_array/ops/take/take_single.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ macro_rules! impl_take_random_get {
macro_rules! impl_take_random_get_unchecked {
($self:ident, $index:ident, $array_type:ty) => {{
let (chunk_idx, idx) = $self.index_to_chunked_index($index);
debug_assert!(chunk_idx < $self.chunks.len());
// Safety:
// bounds are checked above
let arr = $self.chunks.get_unchecked(chunk_idx);
Expand All @@ -40,6 +41,7 @@ macro_rules! impl_take_random_get_unchecked {

// Safety:
// index should be in bounds
debug_assert!(idx < arr.len());
if arr.is_valid_unchecked(idx) {
Some(arr.value_unchecked(idx))
} else {
Expand Down
35 changes: 22 additions & 13 deletions polars/polars-core/src/frame/groupby/aggregations/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use crate::frame::groupby::GroupsIdx;
use crate::frame::groupby::GroupsIndicator;
use crate::prelude::*;
use crate::series::implementations::SeriesWrap;
use crate::series::IsSorted;
use polars_arrow::kernels::rolling;
use polars_arrow::kernels::take_agg::*;
use polars_arrow::prelude::QuantileInterpolOptions;
Expand Down Expand Up @@ -368,6 +369,17 @@ where
ChunkedArray<T>: IntoSeries,
{
pub(crate) unsafe fn agg_min(&self, groups: &GroupsProxy) -> Series {
// faster paths
match (self.is_sorted2(), self.null_count()) {
(IsSorted::Ascending, 0) => {
return self.clone().into_series().agg_first(groups);
}
(IsSorted::Descending, 0) => {
return self.clone().into_series().agg_last(groups);
}
_ => {}
}

match groups {
GroupsProxy::Idx(groups) => agg_helper_idx::<T, _>(groups, |(first, idx)| {
debug_assert!(idx.len() <= self.len());
Expand Down Expand Up @@ -401,12 +413,6 @@ where
groups: groups_slice,
rolling,
} => {
if self.is_sorted() {
return self.clone().into_series().agg_first(groups);
}
if self.is_sorted_reverse() {
return self.clone().into_series().agg_last(groups);
}
if use_rolling_kernels(groups_slice, self.chunks()) {
let arr = self.downcast_iter().next().unwrap();
let values = arr.values().as_slice();
Expand Down Expand Up @@ -447,6 +453,16 @@ where
}

pub(crate) unsafe fn agg_max(&self, groups: &GroupsProxy) -> Series {
// faster paths
match (self.is_sorted2(), self.null_count()) {
(IsSorted::Ascending, 0) => {
return self.clone().into_series().agg_last(groups);
}
(IsSorted::Descending, 0) => {
return self.clone().into_series().agg_first(groups);
}
_ => {}
}
match groups {
GroupsProxy::Idx(groups) => agg_helper_idx::<T, _>(groups, |(first, idx)| {
debug_assert!(idx.len() <= self.len());
Expand Down Expand Up @@ -482,13 +498,6 @@ where
groups: groups_slice,
rolling,
} => {
if self.is_sorted() {
return self.clone().into_series().agg_last(groups);
}
if self.is_sorted_reverse() {
return self.clone().into_series().agg_first(groups);
}

if use_rolling_kernels(groups_slice, self.chunks()) {
let arr = self.downcast_iter().next().unwrap();
let values = arr.values().as_slice();
Expand Down
45 changes: 45 additions & 0 deletions polars/polars-core/src/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ pub(crate) mod series;
use crate::prelude::*;
use crate::POOL;
pub use arrow;
use arrow::bitmap::Bitmap;
pub use polars_arrow::utils::TrustMyLength;
pub use polars_arrow::utils::*;
pub use rayon;
Expand Down Expand Up @@ -955,6 +956,50 @@ pub(crate) fn create_chunked_index_mapping(chunks: &[ArrayRef], len: usize) -> V
vals
}

/// Find the position of the first set validity bit across a sequence of
/// per-chunk validity bitmaps.
///
/// `iter` yields one entry per chunk, in order. A `Some(bitmap)` entry is
/// searched for its first set bit; an entry of `None` (no bitmap) ends the
/// search immediately at the current running offset. The returned index is
/// the number of bits skipped in earlier bitmaps plus the position inside
/// the matching one. `None` is returned when the iterator is exhausted
/// without a hit.
///
/// NOTE(review): the iterator carries no chunk lengths, so a chunk without
/// a bitmap is presumed to hold at least one value — confirm callers never
/// pass zero-length chunks that lack a bitmap.
pub(crate) fn first_non_null<'a, I>(iter: I) -> Option<usize>
where
    I: Iterator<Item = Option<&'a Bitmap>>,
{
    // number of slots covered by the bitmaps already scanned
    let mut skipped = 0;
    for chunk_validity in iter {
        match chunk_validity {
            None => return Some(skipped),
            Some(bitmap) => {
                if let Some(pos) = bitmap.iter().position(|is_valid| is_valid) {
                    return Some(skipped + pos);
                }
                skipped += bitmap.len();
            }
        }
    }
    None
}

/// Find the position of the last set validity bit across a sequence of
/// per-chunk validity bitmaps.
///
/// `iter` yields one entry per chunk and is walked back to front; `len` is
/// the total number of slots the caller attributes to those chunks. A
/// `Some(bitmap)` entry is searched from its end for a set bit; an entry of
/// `None` (no bitmap) ends the search immediately. The result is expressed
/// as an index counted from the front, i.e. `len - 1` minus the number of
/// slots passed over from the back.
///
/// Returns `None` when `len` is zero or when no set bit is found.
///
/// NOTE(review): as with the forward search, a chunk without a bitmap is
/// presumed non-empty because no chunk lengths are available here —
/// confirm against callers.
pub(crate) fn last_non_null<'a, I>(iter: I, len: usize) -> Option<usize>
where
    I: DoubleEndedIterator<Item = Option<&'a Bitmap>>,
{
    if len == 0 {
        return None;
    }
    // index of the final slot; results are measured back from here
    let last_idx = len - 1;
    // number of trailing slots covered by the bitmaps already scanned
    let mut from_back = 0;
    for chunk_validity in iter.rev() {
        match chunk_validity {
            None => return Some(last_idx - from_back),
            Some(bitmap) => {
                if let Some(pos) = bitmap.iter().rev().position(|is_valid| is_valid) {
                    return Some(last_idx - (from_back + pos));
                }
                from_back += bitmap.len();
            }
        }
    }
    None
}

#[cfg(test)]
mod test {
use super::*;
Expand Down
21 changes: 19 additions & 2 deletions polars/polars-lazy/src/physical_plan/planner/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use super::super::expressions as phys_expr;
use crate::prelude::*;
use polars_core::frame::groupby::GroupByMethod;
use polars_core::prelude::*;
use polars_core::series::IsSorted;
use polars_core::utils::parallel_op_series;

impl DefaultPlanner {
Expand Down Expand Up @@ -161,7 +162,15 @@ impl DefaultPlanner {
Context::Default => {
let function = SpecialEq::new(Arc::new(move |s: &mut [Series]| {
let s = std::mem::take(&mut s[0]);
parallel_op_series(|s| Ok(s.min_as_series()), s, None)

match s.is_sorted() {
IsSorted::Ascending | IsSorted::Descending => {
Ok(s.min_as_series())
}
IsSorted::Not => {
parallel_op_series(|s| Ok(s.min_as_series()), s, None)
}
}
})
as Arc<dyn SeriesUdf>);
Ok(Arc::new(ApplyExpr {
Expand All @@ -183,7 +192,15 @@ impl DefaultPlanner {
Context::Default => {
let function = SpecialEq::new(Arc::new(move |s: &mut [Series]| {
let s = std::mem::take(&mut s[0]);
parallel_op_series(|s| Ok(s.max_as_series()), s, None)

match s.is_sorted() {
IsSorted::Ascending | IsSorted::Descending => {
Ok(s.max_as_series())
}
IsSorted::Not => {
parallel_op_series(|s| Ok(s.max_as_series()), s, None)
}
}
})
as Arc<dyn SeriesUdf>);
Ok(Arc::new(ApplyExpr {
Expand Down
48 changes: 48 additions & 0 deletions py-polars/tests/test_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,51 @@ def test_sort_by_exps_nulls_last() -> None:
"row_nr": [0, 4, 2, 1, 3],
"a": [1, 1, -2, 3, None],
}


def test_sort_aggregation_fast_paths() -> None:
    """Min/max taken after sorting must equal min/max of the unsorted frame.

    Every combination of sort direction and null placement is exercised, on
    columns with nulls at the start, at the end, and with a single
    non-null value.
    """
    columns = {
        "a": [None, 3, 2, 1],
        "b": [3, 2, 1, None],
        "c": [3, None, None, None],
        "e": [None, None, None, 1],
        "f": [1, 2, 5, 1],
    }
    df = pl.DataFrame(columns)

    # Reference result: aggregate the unsorted columns.
    plain_max = pl.all().max().suffix("_max")
    plain_min = pl.all().min().suffix("_min")
    expected = df.select([plain_max, plain_min])

    expected_values = {
        "a_max": [3],
        "b_max": [3],
        "c_max": [3],
        "e_max": [1],
        "f_max": [5],
        "a_min": [1],
        "b_min": [1],
        "c_min": [3],
        "e_min": [1],
        "f_min": [1],
    }
    assert expected.to_dict(False) == expected_values

    # Sorting first must not change either aggregate.
    for reverse in (True, False):
        for null_last in (True, False):
            sort_kwargs = {"reverse": reverse, "nulls_last": null_last}
            sorted_max = pl.all().sort(**sort_kwargs).max().suffix("_max")
            sorted_min = pl.all().sort(**sort_kwargs).min().suffix("_min")
            out = df.select([sorted_max, sorted_min])
            assert out.frame_equal(expected)

0 comments on commit 055d1d3

Please sign in to comment.