Skip to content

Commit

Permalink
Fix rolling_max (#2012)
Browse files Browse the repository at this point in the history
  • Loading branch information
mhconradt committed Dec 8, 2021
1 parent 090bb7e commit c233c2c
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 2 deletions.
3 changes: 2 additions & 1 deletion polars/polars-arrow/src/kernels/rolling/no_nulls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ where
.unwrap()
}

fn compute_max<T>(values: &[T]) -> T
pub(crate) fn compute_max<T>(values: &[T]) -> T
where
T: NativeType + PartialOrd,
{
Expand Down Expand Up @@ -181,6 +181,7 @@ where
}
}
}

pub fn rolling_min<T>(
values: &[T],
window_size: usize,
Expand Down
27 changes: 26 additions & 1 deletion polars/polars-arrow/src/kernels/rolling/nulls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ where
{
let null_count = count_zeros(validity_bytes, offset, values.len());
if null_count == 0 {
Some(no_nulls::compute_min(values))
Some(no_nulls::compute_max(values))
} else if (values.len() - null_count) < min_periods {
None
} else {
Expand All @@ -221,6 +221,7 @@ where
out
}
}

pub fn rolling_var<T>(
arr: &PrimitiveArray<T>,
window_size: usize,
Expand Down Expand Up @@ -431,4 +432,28 @@ mod test {
let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();
assert_eq!(out, &[None, None, None, None]);
}

#[test]
fn test_rolling_max_no_nulls() {
let buf = Buffer::from([1.0, 2.0, 3.0, 4.0]);
let arr = &PrimitiveArray::from_data(
DataType::Float64,
buf,
Some(Bitmap::from(&[true, true, true, true])),
);
let out = rolling_max(arr, 4, 1, false, None);
let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();
assert_eq!(out, &[Some(1.0), Some(2.0), Some(3.0), Some(4.0)]);

let out = rolling_max(arr, 2, 2, false, None);
let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();
assert_eq!(out, &[None, Some(2.0), Some(3.0), Some(4.0)]);

let out = rolling_max(arr, 4, 4, false, None);
let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();
assert_eq!(out, &[None, None, None, Some(4.0)])
}
}

0 comments on commit c233c2c

Please sign in to comment.