Skip to content

Commit

Permalink
improve rolling_min/max for columns with null values (#3458)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed May 21, 2022
1 parent 2243ecf commit 6565276
Show file tree
Hide file tree
Showing 8 changed files with 553 additions and 344 deletions.
2 changes: 1 addition & 1 deletion polars/polars-arrow/src/kernels/rolling/mean_no_nulls.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::*;
use crate::kernels::rolling::sum_min_max_no_nulls::SumWindow;
use crate::kernels::rolling::sum_no_nulls::SumWindow;
use no_nulls::{rolling_apply_agg_window, RollingAggWindow};

struct MeanWindow<'a, T> {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,65 +1,6 @@
use super::*;
use no_nulls;
use no_nulls::{rolling_apply_agg_window, RollingAggWindow};
use std::ops::SubAssign;

/// State of a rolling sum over `slice`.
///
/// `new` fills `sum` with the total of `slice[start..end]` and records the
/// bounds; `update` then adjusts `sum` as the window bounds change.
pub(super) struct SumWindow<'a, T> {
// the values the window is computed over
slice: &'a [T],
// sum of the values in the most recently processed window
sum: T,
// start bound (inclusive) of the most recently processed window
last_start: usize,
// end bound (exclusive) of the most recently processed window
last_end: usize,
}

// Rolling sum that is updated incrementally: values that drop out of the
// window are subtracted and values that come in are added, instead of
// re-summing the whole window on every step.
impl<'a, T: NativeType + IsFloat + std::iter::Sum + AddAssign + SubAssign> RollingAggWindow<'a, T>
    for SumWindow<'a, T>
{
    /// Builds the window state from the sum of `slice[start..end]`.
    ///
    /// Panics if `start..end` is not a valid range of `slice` (checked slicing).
    fn new(slice: &'a [T], start: usize, end: usize) -> Self {
        let initial: T = slice[start..end].iter().copied().sum();
        Self {
            slice,
            sum: initial,
            last_start: start,
            last_end: end,
        }
    }

    /// Moves the window to `start..end` and returns the sum of that window.
    ///
    /// # Safety
    ///
    /// The slice is indexed without bounds checks for every index in
    /// `last_start..start`, `last_end..end` and (when a NaN leaves the window)
    /// the range `start..end`; the caller must make sure all of these lie
    /// inside the slice.
    unsafe fn update(&mut self, start: usize, end: usize) -> T {
        // Subtract the values that fall out of the window. A NaN cannot be
        // subtracted back out of a float sum, so once one leaves we stop and
        // fall back to summing the new window from scratch.
        let mut nan_left_window = false;
        let mut pos = self.last_start;
        while pos < start {
            // safety: in bounds, see the contract above
            let outgoing = *self.slice.get_unchecked(pos);

            if T::is_float() && outgoing.is_nan() {
                nan_left_window = true;
                break;
            }

            self.sum -= outgoing;
            pos += 1;
        }
        self.last_start = start;

        if T::is_float() && nan_left_window {
            // the running sum is poisoned by the NaN: recompute the full window
            let fresh: T = self.slice.get_unchecked(start..end).iter().copied().sum();
            self.sum = fresh;
        } else {
            // the running sum is still valid, so only the newly entered
            // values have to be added
            for pos in self.last_end..end {
                let incoming = *self.slice.get_unchecked(pos);
                self.sum += incoming;
            }
        }
        self.last_end = end;

        self.sum
    }
}

struct MinWindow<'a, T: NativeType + PartialOrd + IsFloat> {
slice: &'a [T],
Expand Down Expand Up @@ -198,23 +139,6 @@ impl<'a, T: NativeType + IsFloat + PartialOrd> RollingAggWindow<'a, T> for MaxWi
}
}

/// Returns the smallest value in `values`.
///
/// For float types the first NaN that is encountered is returned immediately.
/// An empty slice yields `T::max_value()`.
pub(crate) fn compute_min<T>(values: &[T]) -> T
where
    T: NativeType + PartialOrd + IsFloat + Bounded,
{
    // `Err` carries a NaN out of the fold early, `Ok` carries the running minimum.
    let outcome = values
        .iter()
        .copied()
        .try_fold(T::max_value(), |lowest, candidate| {
            if T::is_float() && candidate.is_nan() {
                Err(candidate)
            } else if candidate < lowest {
                Ok(candidate)
            } else {
                Ok(lowest)
            }
        });

    match outcome {
        Ok(result) | Err(result) => result,
    }
}

pub(crate) fn compute_min_weights<T>(values: &[T], weights: &[T]) -> T
where
T: NativeType + PartialOrd + std::ops::Mul<Output = T>,
Expand All @@ -227,23 +151,6 @@ where
.unwrap()
}

/// Returns the largest value in `values`.
///
/// For float types the first NaN that is encountered is returned immediately.
/// An empty slice yields `T::min_value()`.
pub(crate) fn compute_max<T>(values: &[T]) -> T
where
    T: NativeType + PartialOrd + IsFloat + Bounded,
{
    // `Err` carries a NaN out of the fold early, `Ok` carries the running maximum.
    let outcome = values
        .iter()
        .copied()
        .try_fold(T::min_value(), |highest, candidate| {
            if T::is_float() && candidate.is_nan() {
                Err(candidate)
            } else if candidate > highest {
                Ok(candidate)
            } else {
                Ok(highest)
            }
        });

    match outcome {
        Ok(result) | Err(result) => result,
    }
}

pub(crate) fn compute_max_weights<T>(values: &[T], weights: &[T]) -> T
where
T: NativeType + PartialOrd + IsFloat + Bounded + Mul<Output = T>,
Expand Down Expand Up @@ -278,9 +185,12 @@ where
min_periods,
det_offsets_center,
),
(false, None) => {
no_nulls::rolling_apply(values, window_size, min_periods, det_offsets, compute_max)
}
(false, None) => rolling_apply_agg_window::<MaxWindow<_>, _, _>(
values,
window_size,
min_periods,
det_offsets,
),
(true, Some(weights)) => {
assert!(
T::is_float(),
Expand Down Expand Up @@ -382,54 +292,6 @@ where
}
}

/// Computes a rolling sum over `values`.
///
/// * `window_size`, `min_periods` - forwarded unchanged to the rolling kernel.
/// * `center` - selects `det_offsets_center` instead of `det_offsets` for
///   determining the window bounds.
/// * `weights` - when given, they are coerced with `no_nulls::coerce_weights`
///   and the weighted kernel (`no_nulls::compute_sum_weights`) is used;
///   otherwise the incremental `SumWindow` aggregation is used.
pub fn rolling_sum<T>(
    values: &[T],
    window_size: usize,
    min_periods: usize,
    center: bool,
    weights: Option<&[f64]>,
) -> ArrayRef
where
    T: NativeType + std::iter::Sum + NumCast + Mul<Output = T> + AddAssign + SubAssign + IsFloat,
{
    match weights {
        // unweighted: incremental window aggregation
        None if center => rolling_apply_agg_window::<SumWindow<_>, _, _>(
            values,
            window_size,
            min_periods,
            det_offsets_center,
        ),
        None => rolling_apply_agg_window::<SumWindow<_>, _, _>(
            values,
            window_size,
            min_periods,
            det_offsets,
        ),
        // weighted: every window is evaluated with the coerced weights
        Some(weights) => {
            let weights = no_nulls::coerce_weights(weights);
            if center {
                no_nulls::rolling_apply_weights(
                    values,
                    window_size,
                    min_periods,
                    det_offsets_center,
                    no_nulls::compute_sum_weights,
                    &weights,
                )
            } else {
                no_nulls::rolling_apply_weights(
                    values,
                    window_size,
                    min_periods,
                    det_offsets,
                    no_nulls::compute_sum_weights,
                    &weights,
                )
            }
        }
    }
}

#[cfg(test)]
mod test {
use super::*;
Expand Down Expand Up @@ -502,58 +364,4 @@ mod test {
)
);
}

use super::*;

#[test]
fn test_rolling_sum() {
    // Runs an unweighted `rolling_sum` and unpacks the resulting array into
    // a vector of optional values so it can be compared directly.
    let sum_to_vec = |values: &[f64], window_size: usize, min_periods: usize, center: bool| {
        let out = rolling_sum(values, window_size, min_periods, center, None);
        let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
        let collected = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();
        collected
    };

    let values: &[f64] = &[1.0f64, 2.0, 3.0, 4.0];

    // window of 2, both values required
    assert_eq!(
        sum_to_vec(values, 2, 2, false),
        &[None, Some(3.0), Some(5.0), Some(7.0)]
    );

    // window of 2, a single value is enough
    assert_eq!(
        sum_to_vec(values, 2, 1, false),
        &[Some(1.0), Some(3.0), Some(5.0), Some(7.0)]
    );

    // window spanning the whole input
    assert_eq!(
        sum_to_vec(values, 4, 1, false),
        &[Some(1.0), Some(3.0), Some(6.0), Some(10.0)]
    );

    // centered windows
    assert_eq!(
        sum_to_vec(values, 4, 1, true),
        &[Some(3.0), Some(6.0), Some(10.0), Some(9.0)]
    );
    assert_eq!(sum_to_vec(values, 4, 4, true), &[None, None, Some(10.0), None]);

    // test nan handling: the sum is NaN while the NaN is inside the window
    // and becomes a regular number again once it has left.
    let values: &[f64] = &[1.0, 2.0, 3.0, f64::nan(), 5.0, 6.0, 7.0];
    let out = sum_to_vec(values, 3, 3, false);

    // NaN != NaN, so the debug representations are compared instead
    let expected = &[
        None,
        None,
        Some(6.0),
        Some(f64::nan()),
        Some(f64::nan()),
        Some(f64::nan()),
        Some(18.0),
    ];
    assert_eq!(
        format!("{:?}", out.as_slice()),
        format!("{:?}", expected)
    );
}
}

0 comments on commit 6565276

Please sign in to comment.