Skip to content

Commit

Permalink
Improve rolling min max (#3531)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed May 30, 2022
1 parent fce5bcd commit ada3a0e
Show file tree
Hide file tree
Showing 28 changed files with 969 additions and 462 deletions.
244 changes: 162 additions & 82 deletions polars/polars-arrow/src/kernels/rolling/no_nulls/min_max.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,21 @@ use super::*;
use no_nulls;
use no_nulls::{rolling_apply_agg_window, RollingAggWindowNoNulls};

pub struct SortedMinMax<'a, T: NativeType> {
slice: &'a [T],
}

impl<'a, T: NativeType> RollingAggWindowNoNulls<'a, T> for SortedMinMax<'a, T> {
fn new(slice: &'a [T], _start: usize, _end: usize) -> Self {
Self { slice }
}

#[inline]
unsafe fn update(&mut self, start: usize, _end: usize) -> T {
*self.slice.get_unchecked(start)
}
}

pub struct MinWindow<'a, T: NativeType + PartialOrd + IsFloat> {
slice: &'a [T],
min: T,
Expand Down Expand Up @@ -74,37 +89,36 @@ impl<'a, T: NativeType + IsFloat + PartialOrd> RollingAggWindowNoNulls<'a, T> fo
Ordering::Less => {
// leaving value could be the smallest, we might need to recompute

// just a random value in the window to prevent O(n^2) behavior
// that can occur when all values in the window are the same
let remaining_value1 = self.slice.get(start).unwrap_unchecked();
let remaining_value2 = self.slice.get(end.saturating_sub(1)).unwrap();

// we check those two value in the window, if they are equal to leaving, we know
// we don't need to traverse all to compote the window
if !matches!(
compare_fn_nan_min(remaining_value1, &self.min),
Ordering::Equal
) && !matches!(
compare_fn_nan_min(remaining_value2, &self.min),
Ordering::Equal
) {
// the minimum value int the window we did not yet compute
let min_in_between = self
.slice
.get_unchecked(start..self.last_end)
.iter()
.min_by(|a, b| compare_fn_nan_min(*a, *b))
.unwrap_or(&self.slice[start]);

if matches!(
compare_fn_nan_min(min_in_between, entering_min),
Ordering::Less
) {
self.min = *min_in_between
} else {
self.min = *entering_min
// check the values in between the window we did not yet
// compute to find the max there. We compare that with the `entering_max`
// if any value is equal to equal to `self.max` of previous window we can break
// early
let mut min_in_between = self.slice.get_unchecked(start);
for idx in (start + 1)..self.last_end {
// safety
// we are in bounds
let value = self.slice.get_unchecked(idx);

if matches!(compare_fn_nan_min(value, min_in_between), Ordering::Less) {
min_in_between = value;
}

// the min is also in the in between values
if matches!(compare_fn_nan_min(value, &self.min), Ordering::Equal) {
self.last_start = start;
self.last_end = end;
return self.min;
}
}

if matches!(
compare_fn_nan_min(min_in_between, entering_min),
Ordering::Less
) {
self.min = *min_in_between
} else {
self.min = *entering_min
}
}
// leaving > entering
Ordering::Greater => {
Expand Down Expand Up @@ -201,37 +215,36 @@ impl<'a, T: NativeType + IsFloat + PartialOrd> RollingAggWindowNoNulls<'a, T> fo
Ordering::Greater => {
// leaving value could be the largest, we might need to recompute

// just a random value in the window to prevent O(n^2) behavior
// that can occur when all values in the window are the same
let remaining_value1 = self.slice.get(start).unwrap_unchecked();
let remaining_value2 = self.slice.get(end.saturating_sub(1)).unwrap();

// we check those two value in the window, if they are equal to leaving, we know
// we don't need to traverse all to compote the window
if !matches!(
compare_fn_nan_max(remaining_value1, &self.max),
Ordering::Equal
) && !matches!(
compare_fn_nan_max(remaining_value2, &self.max),
Ordering::Equal
) {
// the maximum value int the window we did not yet compute
let max_in_between = self
.slice
.get_unchecked(start..self.last_end)
.iter()
.max_by(|a, b| compare_fn_nan_max(*a, *b))
.unwrap_or(&self.slice[start]);

if matches!(
compare_fn_nan_max(max_in_between, entering_max),
Ordering::Greater
) {
self.max = *max_in_between
} else {
self.max = *entering_max
// check the values in between the window we did not yet
// compute to find the max there. We compare that with the `entering_max`
// if any value is equal to equal to `self.max` of previous window we can break
// early
let mut max_in_between = self.slice.get_unchecked(start);
for idx in (start + 1)..self.last_end {
// safety
// we are in bounds
let value = self.slice.get_unchecked(idx);

if matches!(compare_fn_nan_max(value, max_in_between), Ordering::Greater) {
max_in_between = value;
}

// the max is also in the in between values
if matches!(compare_fn_nan_max(value, &self.max), Ordering::Equal) {
self.last_start = start;
self.last_end = end;
return self.max;
}
}

if matches!(
compare_fn_nan_max(max_in_between, entering_max),
Ordering::Greater
) {
self.max = *max_in_between
} else {
self.max = *entering_max
}
}
}
} else if matches!(
Expand Down Expand Up @@ -275,6 +288,16 @@ where
max
}

pub fn is_reverse_sorted_max<T: NativeType + PartialOrd + IsFloat>(values: &[T]) -> bool {
values
.windows(2)
.all(|w| match compare_fn_nan_min(&w[0], &w[1]) {
Ordering::Equal => true,
Ordering::Greater => true,
Ordering::Less => false,
})
}

pub fn rolling_max<T>(
values: &[T],
window_size: usize,
Expand All @@ -286,18 +309,41 @@ where
T: NativeType + PartialOrd + IsFloat + Bounded + NumCast + Mul<Output = T>,
{
match (center, weights) {
(true, None) => rolling_apply_agg_window::<MaxWindow<_>, _, _>(
values,
window_size,
min_periods,
det_offsets_center,
),
(false, None) => rolling_apply_agg_window::<MaxWindow<_>, _, _>(
values,
window_size,
min_periods,
det_offsets,
),
(true, None) => {
// will be O(n2) if we don't take this path we hope that we hit an early return on not sorted data
if is_reverse_sorted_max(values) {
rolling_apply_agg_window::<SortedMinMax<_>, _, _>(
values,
window_size,
min_periods,
det_offsets_center,
)
} else {
rolling_apply_agg_window::<MaxWindow<_>, _, _>(
values,
window_size,
min_periods,
det_offsets_center,
)
}
}
(false, None) => {
if is_reverse_sorted_max(values) {
rolling_apply_agg_window::<SortedMinMax<_>, _, _>(
values,
window_size,
min_periods,
det_offsets,
)
} else {
rolling_apply_agg_window::<MaxWindow<_>, _, _>(
values,
window_size,
min_periods,
det_offsets,
)
}
}
(true, Some(weights)) => {
assert!(
T::is_float(),
Expand Down Expand Up @@ -337,6 +383,16 @@ where
}
}

pub fn is_sorted_min<T: NativeType + PartialOrd + IsFloat>(values: &[T]) -> bool {
values
.windows(2)
.all(|w| match compare_fn_nan_min(&w[0], &w[1]) {
Ordering::Equal => true,
Ordering::Less => true,
Ordering::Greater => false,
})
}

pub fn rolling_min<T>(
values: &[T],
window_size: usize,
Expand All @@ -348,18 +404,42 @@ where
T: NativeType + PartialOrd + NumCast + Mul<Output = T> + Bounded + IsFloat,
{
match (center, weights) {
(true, None) => rolling_apply_agg_window::<MinWindow<_>, _, _>(
values,
window_size,
min_periods,
det_offsets_center,
),
(false, None) => rolling_apply_agg_window::<MinWindow<_>, _, _>(
values,
window_size,
min_periods,
det_offsets,
),
(true, None) => {
// will be O(n2) if we don't take this path we hope that we hit an early return on not sorted data
if is_sorted_min(values) {
rolling_apply_agg_window::<SortedMinMax<_>, _, _>(
values,
window_size,
min_periods,
det_offsets_center,
)
} else {
rolling_apply_agg_window::<MinWindow<_>, _, _>(
values,
window_size,
min_periods,
det_offsets_center,
)
}
}
(false, None) => {
// will be O(n2)
if is_reverse_sorted_max(values) {
rolling_apply_agg_window::<SortedMinMax<_>, _, _>(
values,
window_size,
min_periods,
det_offsets_center,
)
} else {
rolling_apply_agg_window::<MinWindow<_>, _, _>(
values,
window_size,
min_periods,
det_offsets,
)
}
}
(true, Some(weights)) => {
assert!(
T::is_float(),
Expand Down
7 changes: 2 additions & 5 deletions polars/polars-arrow/src/kernels/rolling/no_nulls/variance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,18 +48,15 @@ impl<'a, T: NativeType + IsFloat + std::iter::Sum + AddAssign + SubAssign + Mul<

self.last_start = start;

// we traverese all values and compute
// we traverse all values and compute
if T::is_float() && recompute_sum {
self.sum_of_squares = self
.slice
.get_unchecked(start..end)
.iter()
.map(|v| *v * *v)
.sum::<T>();
}
// the max has not left the window, so we only check
// if the entering values are larger
else {
} else {
for idx in self.last_end..end {
let entering_value = *self.slice.get_unchecked(idx);
self.sum_of_squares += entering_value * entering_value;
Expand Down
13 changes: 5 additions & 8 deletions polars/polars-arrow/src/kernels/rolling/nulls/mean.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,19 @@ impl<
T: NativeType + IsFloat + Add<Output = T> + Sub<Output = T> + NumCast + Div<Output = T>,
> RollingAggWindowNulls<'a, T> for MeanWindow<'a, T>
{
unsafe fn new(
slice: &'a [T],
validity: &'a Bitmap,
start: usize,
end: usize,
min_periods: usize,
) -> Self {
unsafe fn new(slice: &'a [T], validity: &'a Bitmap, start: usize, end: usize) -> Self {
Self {
sum: SumWindow::new(slice, validity, start, end, min_periods),
sum: SumWindow::new(slice, validity, start, end),
}
}

unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
let sum = self.sum.update(start, end);
sum.map(|sum| sum / NumCast::from(end - start - self.sum.null_count).unwrap())
}
fn is_valid(&self, min_periods: usize) -> bool {
self.sum.is_valid(min_periods)
}
}

pub fn rolling_mean<T>(
Expand Down

0 comments on commit ada3a0e

Please sign in to comment.