Skip to content

Commit

Permalink
Improve rolling_sum/rolling_mean for windows with null values. (#3466)
Browse files Browse the repository at this point in the history
* improve rolling_sum nulls

* refactor rolling kernel modules

* improve rolling mean nulls
  • Loading branch information
ritchie46 committed May 22, 2022
1 parent edc4dbd commit 5c128e3
Show file tree
Hide file tree
Showing 11 changed files with 274 additions and 119 deletions.
6 changes: 0 additions & 6 deletions polars/polars-arrow/src/kernels/rolling/mod.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,5 @@
mod mean_no_nulls;
mod min_max_no_nulls;
mod min_max_nulls;
pub mod no_nulls;
pub mod nulls;
mod quantile_no_nulls;
mod quantile_nulls;
mod sum_no_nulls;
mod window;

use crate::data_types::IsFloat;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::sum::SumWindow;
use super::*;
use crate::kernels::rolling::sum_no_nulls::SumWindow;
use no_nulls::{rolling_apply_agg_window, RollingAggWindow};

struct MeanWindow<'a, T> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ impl<'a, T: NativeType + IsFloat + PartialOrd> RollingAggWindow<'a, T> for MinWi
}
self.last_start = start;

// we traverese all values and compute
// we traverse all values and compute
if recompute_min {
self.min = *self
.slice
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
mod mean;
mod min_max;
mod quantile;
mod sum;

use super::*;
use crate::utils::CustomIterTools;
use arrow::array::{ArrayRef, PrimitiveArray};
Expand All @@ -9,10 +14,10 @@ use serde::{Deserialize, Serialize};
use std::fmt::Debug;
use std::sync::Arc;

pub use mean_no_nulls::rolling_mean;
pub use min_max_no_nulls::{rolling_max, rolling_min};
pub use quantile_no_nulls::{rolling_median, rolling_quantile};
pub use sum_no_nulls::rolling_sum;
pub use mean::rolling_mean;
pub use min_max::{rolling_max, rolling_min};
pub use quantile::{rolling_median, rolling_quantile};
pub use sum::rolling_sum;

pub(crate) trait RollingAggWindow<'a, T: NativeType> {
fn new(slice: &'a [T], start: usize, end: usize) -> Self;
Expand Down Expand Up @@ -185,13 +190,6 @@ where
values.iter().zip(weights).map(|(v, w)| *v * *w).sum::<T>() / T::from(values.len()).unwrap()
}

pub(crate) fn compute_sum<T>(values: &[T]) -> T
where
T: std::iter::Sum<T> + Copy,
{
values.iter().copied().sum()
}

pub(crate) fn compute_sum_weights<T>(values: &[T], weights: &[T]) -> T
where
T: std::iter::Sum<T> + Copy + std::ops::Mul<Output = T>,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -242,9 +242,7 @@ where
#[cfg(test)]
mod test {
use super::*;
use crate::kernels::rolling::min_max_no_nulls::{rolling_max, rolling_min};
use arrow::buffer::Buffer;
use arrow::datatypes::DataType;
use crate::kernels::rolling::no_nulls::{rolling_max, rolling_min};

#[test]
fn test_rolling_median() {
Expand Down
75 changes: 75 additions & 0 deletions polars/polars-arrow/src/kernels/rolling/nulls/mean.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
use super::sum::SumWindow;
use super::*;
use super::{rolling_apply_agg_window, RollingAggWindow};

struct MeanWindow<'a, T> {
sum: SumWindow<'a, T>,
}

impl<
'a,
T: NativeType
+ IsFloat
+ PartialOrd
+ Add<Output = T>
+ Sub<Output = T>
+ NumCast
+ Div<Output = T>,
> RollingAggWindow<'a, T> for MeanWindow<'a, T>
{
unsafe fn new(
slice: &'a [T],
validity: &'a Bitmap,
start: usize,
end: usize,
min_periods: usize,
) -> Self {
Self {
sum: SumWindow::new(slice, validity, start, end, min_periods),
}
}

unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
let sum = self.sum.update(start, end);
dbg!(sum);
sum.map(|sum| sum / NumCast::from(end - start - self.sum.null_count).unwrap())
}
}

pub fn rolling_mean<T>(
arr: &PrimitiveArray<T>,
window_size: usize,
min_periods: usize,
center: bool,
weights: Option<&[f64]>,
) -> ArrayRef
where
T: NativeType
+ IsFloat
+ PartialOrd
+ Add<Output = T>
+ Sub<Output = T>
+ NumCast
+ Div<Output = T>,
{
if weights.is_some() {
panic!("weights not yet supported on array with null values")
}
if center {
rolling_apply_agg_window::<MeanWindow<_>, _, _>(
arr.values().as_slice(),
arr.validity().as_ref().unwrap(),
window_size,
min_periods,
det_offsets_center,
)
} else {
rolling_apply_agg_window::<MeanWindow<_>, _, _>(
arr.values().as_slice(),
arr.validity().as_ref().unwrap(),
window_size,
min_periods,
det_offsets,
)
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
mod mean;
mod min_max;
mod quantile;
mod sum;

use super::*;
pub use min_max_nulls::{rolling_max, rolling_min};
pub use quantile_nulls::{rolling_median, rolling_quantile};
pub use mean::rolling_mean;
pub use min_max::{rolling_max, rolling_min};
pub use quantile::{rolling_median, rolling_quantile};
pub use sum::rolling_sum;

pub(crate) trait RollingAggWindow<'a, T: NativeType> {
unsafe fn new(
Expand Down Expand Up @@ -117,33 +124,6 @@ where
))
}

fn compute_sum<T>(
values: &[T],
validity_bytes: &[u8],
offset: usize,
min_periods: usize,
) -> Option<T>
where
T: NativeType + std::iter::Sum<T> + Zero + AddAssign,
{
let null_count = count_zeros(validity_bytes, offset, values.len());
if null_count == 0 {
Some(no_nulls::compute_sum(values))
} else if (values.len() - null_count) < min_periods {
None
} else {
let mut out = Zero::zero();
for (i, val) in values.iter().enumerate() {
// Safety:
// in bounds
if unsafe { get_bit_unchecked(validity_bytes, offset + i) } {
out += *val;
}
}
Some(out)
}
}

fn compute_mean<T>(
values: &[T],
validity_bytes: &[u8],
Expand Down Expand Up @@ -242,77 +222,10 @@ where
}
}

pub fn rolling_sum<T>(
arr: &PrimitiveArray<T>,
window_size: usize,
min_periods: usize,
center: bool,
weights: Option<&[f64]>,
) -> ArrayRef
where
T: NativeType + std::iter::Sum + Zero + AddAssign + Copy,
{
if weights.is_some() {
panic!("weights not yet supported on array with null values")
}
if center {
rolling_apply(
arr.values().as_slice(),
arr.validity().as_ref().unwrap(),
window_size,
min_periods,
det_offsets_center,
compute_sum,
)
} else {
rolling_apply(
arr.values().as_slice(),
arr.validity().as_ref().unwrap(),
window_size,
min_periods,
det_offsets,
compute_sum,
)
}
}

pub fn rolling_mean<T>(
arr: &PrimitiveArray<T>,
window_size: usize,
min_periods: usize,
center: bool,
weights: Option<&[f64]>,
) -> ArrayRef
where
T: NativeType + std::iter::Sum + Zero + AddAssign + Copy + Float,
{
if weights.is_some() {
panic!("weights not yet supported on array with null values")
}
if center {
rolling_apply(
arr.values().as_slice(),
arr.validity().as_ref().unwrap(),
window_size,
min_periods,
det_offsets_center,
compute_mean,
)
} else {
rolling_apply(
arr.values().as_slice(),
arr.validity().as_ref().unwrap(),
window_size,
min_periods,
det_offsets,
compute_mean,
)
}
}

#[cfg(test)]
mod test {
use super::*;
use crate::kernels::rolling::nulls::mean::rolling_mean;
use arrow::buffer::Buffer;
use arrow::datatypes::DataType;

Expand Down Expand Up @@ -351,6 +264,32 @@ mod test {
assert_eq!(out, &[None, None, None, None]);
}

#[test]
fn test_rolling_mean_nulls() {
// 1, None, -1, 4
let buf = Buffer::from(vec![1.0, 0.0, -1.0, 4.0]);
let arr = &PrimitiveArray::from_data(
DataType::Float64,
buf,
Some(Bitmap::from(&[true, false, true, true])),
);

let out = rolling_mean(arr, 2, 2, false, None);
let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();
assert_eq!(out, &[None, None, None, Some(1.5)]);

let out = rolling_mean(arr, 2, 1, false, None);
let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();
assert_eq!(out, &[Some(1.0), Some(1.0), Some(-1.0), Some(1.5)]);

let out = rolling_mean(arr, 4, 1, false, None);
let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();
assert_eq!(out, &[Some(1.0), Some(1.0), Some(0.0), Some(4.0 / 3.0)]);
}

#[test]
fn test_rolling_max_no_nulls() {
let buf = Buffer::from(vec![1.0, 2.0, 3.0, 4.0]);
Expand Down

0 comments on commit 5c128e3

Please sign in to comment.