Skip to content

Commit

Permalink
improve rolling_var performance (#3470)
Browse files Browse the repository at this point in the history
* improve rolling_var performance no_nulls

* improve performance rolling_variance with nulls
  • Loading branch information
ritchie46 committed May 22, 2022
1 parent 69c57d4 commit f14570a
Show file tree
Hide file tree
Showing 11 changed files with 475 additions and 283 deletions.
3 changes: 1 addition & 2 deletions .github/workflows/test-windows-python.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
with:
toolchain: nightly-2022-05-20
override: true
components: rustfmt, clippy
components: rustfmt
- name: Set up Python
uses: actions/setup-python@v3
with:
Expand All @@ -29,7 +29,6 @@ jobs:
run: |
export RUSTFLAGS="-C debuginfo=0"
cd py-polars && rustup override set nightly-2022-05-01 && make build-and-test-no-venv
cargo clippy
# test if we can import polars without any requirements
- name: Import polars
run: |
Expand Down
1 change: 0 additions & 1 deletion polars/polars-arrow/src/kernels/rolling/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ use crate::data_types::IsFloat;
use crate::prelude::QuantileInterpolOptions;
use crate::utils::CustomIterTools;
use arrow::array::{ArrayRef, PrimitiveArray};
use arrow::bitmap::utils::{count_zeros, get_bit_unchecked};
use arrow::bitmap::{Bitmap, MutableBitmap};
use arrow::types::NativeType;
use num::ToPrimitive;
Expand Down
2 changes: 1 addition & 1 deletion polars/polars-arrow/src/kernels/rolling/no_nulls/mean.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use super::sum::SumWindow;
use super::*;
use no_nulls::{rolling_apply_agg_window, RollingAggWindow};

struct MeanWindow<'a, T> {
pub(super) struct MeanWindow<'a, T> {
sum: SumWindow<'a, T>,
}

Expand Down
105 changes: 3 additions & 102 deletions polars/polars-arrow/src/kernels/rolling/no_nulls/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ mod mean;
mod min_max;
mod quantile;
mod sum;
mod variance;

use super::*;
use crate::utils::CustomIterTools;
Expand All @@ -18,6 +19,7 @@ pub use mean::rolling_mean;
pub use min_max::{rolling_max, rolling_min};
pub use quantile::{rolling_median, rolling_quantile};
pub use sum::rolling_sum;
pub use variance::rolling_var;

pub(crate) trait RollingAggWindow<'a, T: NativeType> {
fn new(slice: &'a [T], start: usize, end: usize) -> Self;
Expand Down Expand Up @@ -106,83 +108,26 @@ where
))
}

/// Evaluates `aggregator` on one window of `values` per index and collects the
/// results into a primitive array.
///
/// For every `idx` in `0..values.len()` the window bounds are obtained from
/// `det_offsets_fn(idx, window_size, len)`. The validity bitmap comes from
/// `create_validity`, which receives `min_periods`, the input length, the
/// window size and the same offsets function.
///
/// NOTE(review): windows are sliced with `get_unchecked`, so this relies on
/// `det_offsets_fn` only ever returning ordered, in-bounds `(start, end)`
/// pairs — confirm against the offset functions passed by callers.
pub(super) fn rolling_apply<T, K, Fo, Fa>(
    values: &[T],
    window_size: usize,
    min_periods: usize,
    det_offsets_fn: Fo,
    aggregator: Fa,
) -> ArrayRef
where
    T: Debug,
    K: NativeType,
    Fo: Fn(Idx, WindowSize, Len) -> (Start, End),
    Fa: Fn(&[T]) -> K,
{
    let n = values.len();

    // one aggregated value per input position
    let aggregated = (0..n)
        .map(|i| {
            let (lo, hi) = det_offsets_fn(i, window_size, n);
            let window = unsafe { values.get_unchecked(lo..hi) };
            aggregator(window)
        })
        .collect_trusted::<Vec<K>>();

    // the offsets function is only moved once the per-window pass is done
    let validity = create_validity(min_periods, n, window_size, det_offsets_fn);

    let array = PrimitiveArray::from_data(
        K::PRIMITIVE.into(),
        aggregated.into(),
        validity.map(|bitmap| bitmap.into()),
    );
    Arc::new(array)
}

/// Computes the variance of `vals` with Bessel's correction applied.
///
/// A single pass accumulates the element count, the sum and the sum of
/// squares; the variance is then derived as `E[x^2] - E[x]^2` and rescaled by
/// `count / (count - 1)`.
pub(crate) fn compute_var<T>(vals: &[T]) -> T
where
    T: Float + std::ops::AddAssign + std::fmt::Debug,
{
    // (count, sum, sum of squares), accumulated in slice order
    let (count, sum, sum_of_squares) = vals.iter().fold(
        (T::zero(), T::zero(), T::zero()),
        |(n, total, squares), &v| (n + T::one(), total + v, squares + v * v),
    );

    let mean = sum / count;
    let mean_of_squares = sum_of_squares / count;
    let population_var = mean_of_squares - mean * mean;

    // apply Bessel's correction
    population_var / (count - T::one()) * count
}

/// Computes the variance of the element-wise products `vals[i] * weights[i]`,
/// with Bessel's correction applied.
///
/// The products are formed by zipping `vals` with `weights`, so iteration
/// stops at the shorter of the two slices. The divisor is always derived from
/// `vals.len()`.
fn compute_var_weights<T>(vals: &[T], weights: &[T]) -> T
where
    T: Float + std::ops::AddAssign,
{
    let weighted_iter = vals.iter().zip(weights).map(|(x, y)| *x * *y);

    let mut sum = T::zero();
    let mut sum_of_squares = T::zero();

    for val in weighted_iter {
        sum += val;
        sum_of_squares += val * val;
    }
    // The element count comes straight from the slice length; the previous
    // per-element counter was shadowed by this binding and never read.
    let count: T = NumCast::from(vals.len()).unwrap();

    let mean = sum / count;
    // apply Bessel's correction
    ((sum_of_squares / count) - mean * mean) / (count - T::one()) * count
}

/// Returns the arithmetic mean of `values`: their sum divided by the slice
/// length converted to `T`.
pub(crate) fn compute_mean<T>(values: &[T]) -> T
where
    T: Float + std::iter::Sum<T>,
{
    let total = values.iter().copied().sum::<T>();
    let len = T::from(values.len()).unwrap();
    total / len
}

pub(crate) fn compute_mean_weights<T>(values: &[T], weights: &[T]) -> T
where
T: Float + std::iter::Sum<T>,
Expand All @@ -205,47 +150,3 @@ where
.map(|v| NumCast::from(*v).unwrap())
.collect::<Vec<_>>()
}

/// Computes a rolling variance over `values`.
///
/// Without weights every window is aggregated with `compute_var`; with
/// weights they are first coerced via `coerce_weights` and every window is
/// aggregated with `compute_var_weights`. `center` selects
/// `det_offsets_center` instead of `det_offsets` for the window bounds.
pub fn rolling_var<T>(
    values: &[T],
    window_size: usize,
    min_periods: usize,
    center: bool,
    weights: Option<&[f64]>,
) -> ArrayRef
where
    T: NativeType + Float + std::ops::AddAssign,
{
    match weights {
        None if center => rolling_apply(
            values,
            window_size,
            min_periods,
            det_offsets_center,
            compute_var,
        ),
        None => rolling_apply(values, window_size, min_periods, det_offsets, compute_var),
        Some(weights) => {
            let weights = coerce_weights(weights);
            if center {
                rolling_apply_weights(
                    values,
                    window_size,
                    min_periods,
                    det_offsets_center,
                    compute_var_weights,
                    &weights,
                )
            } else {
                rolling_apply_weights(
                    values,
                    window_size,
                    min_periods,
                    det_offsets,
                    compute_var_weights,
                    &weights,
                )
            }
        }
    }
}
206 changes: 206 additions & 0 deletions polars/polars-arrow/src/kernels/rolling/no_nulls/variance.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
use super::mean::MeanWindow;
use super::*;
use no_nulls::{rolling_apply_agg_window, RollingAggWindow};

/// Rolling-window state that keeps the sum of squared values of the most
/// recently evaluated window.
pub(super) struct SumSquaredWindow<'a, T> {
    // the complete input slice that window bounds index into
    slice: &'a [T],
    // sum of `v * v` over the elements of the most recent window
    sum_of_squares: T,
    // start (inclusive) of the most recent window
    last_start: usize,
    // end (exclusive) of the most recent window
    last_end: usize,
}

impl<'a, T> RollingAggWindow<'a, T> for SumSquaredWindow<'a, T>
where
    T: NativeType + IsFloat + std::iter::Sum + AddAssign + SubAssign + Mul<Output = T>,
{
    /// Creates the window state for `slice[start..end]`, summing the squares
    /// of that range up front.
    fn new(slice: &'a [T], start: usize, end: usize) -> Self {
        let sum_of_squares = slice[start..end].iter().map(|&v| v * v).sum::<T>();
        Self {
            slice,
            sum_of_squares,
            last_start: start,
            last_end: end,
        }
    }

    /// Moves the window to `start..end` and returns the updated sum of squares.
    ///
    /// NOTE(review): elements are read with `get_unchecked`; this presumably
    /// requires `start..end` (and the previous bounds) to lie inside the
    /// slice — confirm with the caller.
    unsafe fn update(&mut self, start: usize, end: usize) -> T {
        // subtract the squares of the elements that drop out of the window
        let mut nan_left_window = false;
        for idx in self.last_start..start {
            let leaving = *self.slice.get_unchecked(idx);

            // a NaN cannot be subtracted back out of the running sum,
            // so the whole window has to be summed again
            if T::is_float() && leaving.is_nan() {
                nan_left_window = true;
                break;
            }

            self.sum_of_squares -= leaving * leaving;
        }
        self.last_start = start;

        if nan_left_window {
            // traverse every value of the new window and sum from scratch
            self.sum_of_squares = self
                .slice
                .get_unchecked(start..end)
                .iter()
                .map(|&v| v * v)
                .sum::<T>();
        } else {
            // the running sum is still usable: only add the squares of the
            // elements that newly entered the window
            for idx in self.last_end..end {
                let entering = *self.slice.get_unchecked(idx);
                self.sum_of_squares += entering * entering;
            }
        }
        self.last_end = end;

        self.sum_of_squares
    }
}

// Rolling-window state for the variance.
//
// The variance
//   E[(xi - E[x])^2]
// can be expanded to
//   E[x^2] - E[x]^2
// so it is derived from a rolling mean and a rolling sum of squares.
struct VarWindow<'a, T> {
    // rolling mean of the window, i.e. E[x]
    mean: MeanWindow<'a, T>,
    // rolling sum of squared values; divided by the window length it gives E[x^2]
    sum_of_squares: SumSquaredWindow<'a, T>,
}

impl<'a, T> RollingAggWindow<'a, T> for VarWindow<'a, T>
where
    T: NativeType
        + IsFloat
        + std::iter::Sum
        + AddAssign
        + SubAssign
        + Div<Output = T>
        + NumCast
        + One
        + Sub<Output = T>,
{
    /// Creates both underlying windows (mean and sum of squares) for
    /// `slice[start..end]`.
    fn new(slice: &'a [T], start: usize, end: usize) -> Self {
        let mean = MeanWindow::new(slice, start, end);
        let sum_of_squares = SumSquaredWindow::new(slice, start, end);
        Self {
            mean,
            sum_of_squares,
        }
    }

    /// Moves both underlying windows to `start..end` and returns the variance
    /// of that window with Bessel's correction applied.
    unsafe fn update(&mut self, start: usize, end: usize) -> T {
        let n: T = NumCast::from(end - start).unwrap();

        // E[x^2]
        let mean_of_squares = self.sum_of_squares.update(start, end) / n;
        // E[x]
        let mean = self.mean.update(start, end);

        let population_var = mean_of_squares - mean * mean;
        // apply Bessel's correction
        population_var / (n - T::one()) * n
    }
}

/// Computes a rolling variance (with Bessel's correction) over `values`.
///
/// Without weights the result is produced incrementally through `VarWindow`.
/// With weights they are coerced via `coerce_weights` and every window is
/// aggregated with `compute_var_weights`. `center` selects
/// `det_offsets_center` instead of `det_offsets` for the window bounds;
/// `window_size` and `min_periods` are forwarded unchanged.
pub fn rolling_var<T>(
    values: &[T],
    window_size: usize,
    min_periods: usize,
    center: bool,
    weights: Option<&[f64]>,
) -> ArrayRef
where
    T: NativeType
        + Float
        + IsFloat
        + std::iter::Sum
        + AddAssign
        + SubAssign
        + Div<Output = T>
        + NumCast
        + One
        + Sub<Output = T>,
{
    match weights {
        None if center => rolling_apply_agg_window::<VarWindow<_>, _, _>(
            values,
            window_size,
            min_periods,
            det_offsets_center,
        ),
        None => rolling_apply_agg_window::<VarWindow<_>, _, _>(
            values,
            window_size,
            min_periods,
            det_offsets,
        ),
        Some(weights) => {
            let weights = coerce_weights(weights);
            if center {
                super::rolling_apply_weights(
                    values,
                    window_size,
                    min_periods,
                    det_offsets_center,
                    compute_var_weights,
                    &weights,
                )
            } else {
                super::rolling_apply_weights(
                    values,
                    window_size,
                    min_periods,
                    det_offsets,
                    compute_var_weights,
                    &weights,
                )
            }
        }
    }
}

#[cfg(test)]
mod test {
    use super::*;

    // Downcasts a rolling result to an `f64` primitive array and copies out
    // its optional values.
    fn to_options(arr: &ArrayRef) -> Vec<Option<f64>> {
        arr.as_any()
            .downcast_ref::<PrimitiveArray<f64>>()
            .unwrap()
            .into_iter()
            .map(|v| v.copied())
            .collect::<Vec<_>>()
    }

    #[test]
    fn test_rolling_var() {
        let values = &[1.0f64, 5.0, 3.0, 4.0];

        // min_periods equal to the window size: the first slot is null
        let got = to_options(&rolling_var(values, 2, 2, false, None));
        assert_eq!(got, &[None, Some(8.0), Some(2.0), Some(0.5)]);

        // min_periods of one: every slot is valid, the first one is NaN
        let got = to_options(&rolling_var(values, 2, 1, false, None))
            .into_iter()
            .map(|v| v.unwrap())
            .collect::<Vec<_>>();
        // we cannot compare nans, so we compare the string values
        assert_eq!(
            format!("{:?}", got.as_slice()),
            format!("{:?}", &[f64::nan(), 8.0, 2.0, 0.5])
        );

        // test nan handling.
        let values = &[-10.0, 2.0, 3.0, f64::nan(), 5.0, 6.0, 7.0];
        let got = to_options(&rolling_var(values, 3, 3, false, None));
        let expected = [
            None,
            None,
            Some(52.33333333333333),
            Some(f64::nan()),
            Some(f64::nan()),
            Some(f64::nan()),
            Some(0.9999999999999964),
        ];
        // we cannot compare nans, so we compare the string values
        assert_eq!(
            format!("{:?}", got.as_slice()),
            format!("{:?}", &expected)
        );
    }
}

0 comments on commit f14570a

Please sign in to comment.