Skip to content

Commit

Permalink
fix[rust]: fix rolling_min/max for arrays with null values (#4446)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Aug 16, 2022
1 parent c171634 commit 39fcfc5
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 50 deletions.
1 change: 1 addition & 0 deletions polars/polars-arrow/src/compute/mod.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
#[cfg(feature = "compute")]
pub mod cast;
pub mod take;
115 changes: 66 additions & 49 deletions polars/polars-arrow/src/kernels/rolling/nulls/min_max.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,73 +85,74 @@ impl<'a, T: NativeType> RollingAggWindowNulls<'a, T> for SortedMinMax<'a, T> {
}
}

/// Generic `Min` / `Max` kernel. It is written in terms of `Min` aggregation,
/// but applies to `max` as well, just mentally `:s/min/max/g`.
/// Generic `Min` / `Max` kernel.
pub struct MinMaxWindow<'a, T: NativeType + PartialOrd + IsFloat> {
slice: &'a [T],
validity: &'a Bitmap,
min: Option<T>,
extremum: Option<T>,
last_start: usize,
last_end: usize,
null_count: usize,
compare_fn_nan: fn(&T, &T) -> Ordering,
take_extremum: fn(T, T) -> T,
// ordering on which the window needs to act.
// for min kernel this is Less
// for max kernel this is Greater
agg_ordering: Ordering,
}

impl<'a, T: NativeType + IsFloat + PartialOrd> MinMaxWindow<'a, T> {
unsafe fn compute_min_in_between_leaving_and_entering(&self, start: usize) -> Option<T> {
unsafe fn compute_extremum_in_between_leaving_and_entering(&self, start: usize) -> Option<T> {
// check the values in between the window that remains e.g. is not leaving
// this between `start..last_end`
//
// because we know the current `min` (which might be leaving), we know we can stop
// searching if any value is equal to current `min`.
let mut min_in_between = None;
let mut extremum_in_between = None;
for idx in start..self.last_end {
let valid = self.validity.get_bit_unchecked(idx);
let value = self.slice.get_unchecked(idx);

if valid {
// early return
if let Some(current_min) = self.min {
if let Some(current_min) = self.extremum {
if matches!(compare_fn_nan_min(value, &current_min), Ordering::Equal) {
return Some(current_min);
}
}

match min_in_between {
None => min_in_between = Some(*value),
match extremum_in_between {
None => extremum_in_between = Some(*value),
Some(current) => {
min_in_between =
Some(std::cmp::min_by(*value, current, self.compare_fn_nan))
extremum_in_between = Some((self.take_extremum)(*value, current))
}
}
}
}
min_in_between
extremum_in_between
}

// compute min from the entire window
unsafe fn compute_min_and_update_null_count(&mut self, start: usize, end: usize) -> Option<T> {
let mut min = None;
unsafe fn compute_extremum_and_update_null_count(
&mut self,
start: usize,
end: usize,
) -> Option<T> {
let mut extremum = None;
let mut idx = start;
for value in &self.slice[start..end] {
let valid = self.validity.get_bit_unchecked(idx);
if valid {
match min {
None => min = Some(*value),
Some(current) => {
min = Some(std::cmp::min_by(*value, current, self.compare_fn_nan))
}
match extremum {
None => extremum = Some(*value),
Some(current) => extremum = Some((self.take_extremum)(*value, current)),
}
} else {
self.null_count += 1;
}
idx += 1;
}
min
extremum
}

unsafe fn new(
Expand All @@ -160,34 +161,36 @@ impl<'a, T: NativeType + IsFloat + PartialOrd> MinMaxWindow<'a, T> {
start: usize,
end: usize,
compare_fn: fn(&T, &T) -> Ordering,
take_extremum: fn(T, T) -> T,
agg_ordering: Ordering,
) -> Self {
let mut out = Self {
slice,
validity,
min: None,
extremum: None,
last_start: start,
last_end: end,
null_count: 0,
compare_fn_nan: compare_fn,
take_extremum,
agg_ordering,
};
let min = out.compute_min_and_update_null_count(start, end);
out.min = min;
let extremum = out.compute_extremum_and_update_null_count(start, end);
out.extremum = extremum;
out
}

unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
// recompute min
if start >= self.last_end {
self.min = self.compute_min_and_update_null_count(start, end);
self.extremum = self.compute_extremum_and_update_null_count(start, end);
self.last_end = end;
self.last_start = start;
return self.min;
return self.extremum;
}

// remove elements that should leave the window
let mut recompute_min = false;
let mut recompute_extremum = false;
for idx in self.last_start..start {
// safety
// we are in bounds
Expand All @@ -198,10 +201,10 @@ impl<'a, T: NativeType + IsFloat + PartialOrd> MinMaxWindow<'a, T> {
// if the leaving value is the
// min value, we need to recompute the min.
if matches!(
(self.compare_fn_nan)(leaving_value, &self.min.unwrap()),
(self.compare_fn_nan)(leaving_value, &self.extremum.unwrap()),
Ordering::Equal
) {
recompute_min = true;
recompute_extremum = true;
break;
}
} else {
Expand All @@ -210,68 +213,72 @@ impl<'a, T: NativeType + IsFloat + PartialOrd> MinMaxWindow<'a, T> {

// self.min is None and the leaving value is None
// if the entering value is valid, we might get a new min.
if self.min.is_none() {
recompute_min = true;
if self.extremum.is_none() {
recompute_extremum = true;
break;
}
}
}

let entering_min = self.compute_min_and_update_null_count(self.last_end, end);
let entering_extremum = self.compute_extremum_and_update_null_count(self.last_end, end);

match (self.min, entering_min) {
match (self.extremum, entering_extremum) {
// all remains `None`
(None, None) => {}
(None, Some(new_min)) => self.min = Some(new_min),
(None, Some(new_min)) => self.extremum = Some(new_min),
// entering min is `None` and the `min` is leaving, so the `in_between` min is the new
// minimum.
// if min is not leaving, we don't do anything
(Some(_current_min), None) => {
if recompute_min {
self.min = self.compute_min_in_between_leaving_and_entering(start);
if recompute_extremum {
self.extremum = self.compute_extremum_in_between_leaving_and_entering(start);
}
}
(Some(current_min), Some(entering_min)) => {
if recompute_min {
match (self.compare_fn_nan)(&current_min, &entering_min) {
(Some(current_extremum), Some(entering_extremum)) => {
if recompute_extremum {
match (self.compare_fn_nan)(&current_extremum, &entering_extremum) {
// do nothing
Ordering::Equal => {}
// leaving < entering
ord if ord == self.agg_ordering => {
// leaving value could be the smallest, we might need to recompute

let min_in_between =
self.compute_min_in_between_leaving_and_entering(start);
self.compute_extremum_in_between_leaving_and_entering(start);
match min_in_between {
None => self.min = Some(entering_min),
Some(min_in_between) => {
if (self.compare_fn_nan)(&min_in_between, &entering_min)
== self.agg_ordering
None => self.extremum = Some(entering_extremum),
Some(extremum_in_between) => {
if (self.compare_fn_nan)(
&extremum_in_between,
&entering_extremum,
) == self.agg_ordering
{
self.min = Some(min_in_between)
self.extremum = Some(extremum_in_between)
} else {
self.min = Some(entering_min)
self.extremum = Some(entering_extremum)
}
}
}
}
// leaving > entering
_ => {
if (self.compare_fn_nan)(&entering_min, &current_min)
if (self.compare_fn_nan)(&entering_extremum, &current_extremum)
== self.agg_ordering
{
self.min = Some(entering_min)
self.extremum = Some(entering_extremum)
}
}
}
} else if (self.compare_fn_nan)(&entering_min, &current_min) == self.agg_ordering {
self.min = Some(entering_min)
} else if (self.compare_fn_nan)(&entering_extremum, &current_extremum)
== self.agg_ordering
{
self.extremum = Some(entering_extremum)
}
}
}
self.last_start = start;
self.last_end = end;
self.min
self.extremum
}

fn is_valid(&self, min_periods: usize) -> bool {
Expand All @@ -283,6 +290,10 @@ pub struct MinWindow<'a, T: NativeType + PartialOrd + IsFloat> {
inner: MinMaxWindow<'a, T>,
}

fn take_min<T: NativeType + IsFloat + PartialOrd>(a: T, b: T) -> T {
std::cmp::min_by(a, b, compare_fn_nan_min)
}

impl<'a, T: NativeType + IsFloat + PartialOrd> RollingAggWindowNulls<'a, T> for MinWindow<'a, T> {
unsafe fn new(slice: &'a [T], validity: &'a Bitmap, start: usize, end: usize) -> Self {
Self {
Expand All @@ -292,6 +303,7 @@ impl<'a, T: NativeType + IsFloat + PartialOrd> RollingAggWindowNulls<'a, T> for
start,
end,
compare_fn_nan_min,
take_min,
Ordering::Less,
),
}
Expand Down Expand Up @@ -342,6 +354,10 @@ pub struct MaxWindow<'a, T: NativeType + PartialOrd + IsFloat> {
inner: MinMaxWindow<'a, T>,
}

fn take_max<T: NativeType + IsFloat + PartialOrd>(a: T, b: T) -> T {
std::cmp::max_by(a, b, compare_fn_nan_max)
}

impl<'a, T: NativeType + IsFloat + PartialOrd> RollingAggWindowNulls<'a, T> for MaxWindow<'a, T> {
unsafe fn new(slice: &'a [T], validity: &'a Bitmap, start: usize, end: usize) -> Self {
Self {
Expand All @@ -351,6 +367,7 @@ impl<'a, T: NativeType + IsFloat + PartialOrd> RollingAggWindowNulls<'a, T> for
start,
end,
compare_fn_nan_max,
take_max,
Ordering::Greater,
),
}
Expand Down
27 changes: 27 additions & 0 deletions polars/polars-arrow/src/kernels/rolling/nulls/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ where

#[cfg(test)]
mod test {
use arrow::array::{Array, Int32Array};
use arrow::buffer::Buffer;
use arrow::datatypes::DataType;

Expand Down Expand Up @@ -220,4 +221,30 @@ mod test {
let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();
assert_eq!(out, &[Some(4.0), Some(4.0), Some(3.0), Some(2.0)]);
}

#[test]
fn test_rolling_extrema_nulls() {
let vals = vec![3, 3, 3, 10, 10, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1];
let mut validity = MutableBitmap::new();
validity.extend_constant(vals.len(), true);

let window_size = 3;
let min_periods = 3;

let arr = Int32Array::new(DataType::Int32, vals.into(), Some(validity.into()));

let out = rolling_apply_agg_window::<MaxWindow<_>, _, _>(
arr.values().as_slice(),
arr.validity().as_ref().unwrap(),
window_size,
min_periods,
det_offsets,
);
let arr = out.as_any().downcast_ref::<Int32Array>().unwrap();
assert_eq!(arr.null_count(), 2);
assert_eq!(
&arr.values().as_slice()[2..],
&[3, 10, 10, 10, 10, 10, 9, 8, 7, 6, 5, 4, 3]
);
}
}
1 change: 0 additions & 1 deletion polars/polars-arrow/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
pub mod array;
pub mod bit_util;
mod bitmap;
#[cfg(feature = "compute")]
pub mod compute;
pub mod conversion;
pub mod data_types;
Expand Down

0 comments on commit 39fcfc5

Please sign in to comment.