Skip to content

Commit

Permalink
improve rolling_{min/max/sum/mean} performance ~3.4x (#3444)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed May 20, 2022
1 parent 24ddb90 commit 24cc2ef
Show file tree
Hide file tree
Showing 9 changed files with 737 additions and 342 deletions.
72 changes: 72 additions & 0 deletions polars/polars-arrow/src/kernels/rolling/mean_no_nulls.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
use super::*;
use crate::kernels::rolling::sum_min_max_no_nulls::SumWindow;
use no_nulls::{rolling_apply_agg_window, RollingAggWindow};

/// Rolling-window aggregator that yields the arithmetic mean of the current window.
///
/// It holds nothing but a [`SumWindow`]: every update asks that inner window
/// for the sum and divides the result by the window length (`end - start`).
struct MeanWindow<'a, T> {
    /// Sum aggregator over the same slice; supplies the numerator of the mean.
    sum: SumWindow<'a, T>,
}

impl<'a, T> RollingAggWindow<'a, T> for MeanWindow<'a, T>
where
    T: NativeType + IsFloat + std::iter::Sum + AddAssign + SubAssign + Div<Output = T> + NumCast,
{
    /// Creates the mean window by constructing the inner [`SumWindow`] from the
    /// very same `slice`, `start` and `end`.
    fn new(slice: &'a [T], start: usize, end: usize) -> Self {
        let sum = SumWindow::new(slice, start, end);
        Self { sum }
    }

    /// Moves the window to `start..end` and returns the sum reported by the
    /// inner [`SumWindow`] divided by `end - start`.
    ///
    /// # Safety
    /// The arguments are handed straight to `SumWindow::update`, so its safety
    /// contract applies unchanged.
    /// NOTE(review): presumably `start..end` must lie within the slice given to
    /// `new` -- confirm against `SumWindow`.
    ///
    /// # Panics
    /// Panics when `end - start` cannot be converted into `T` (the `unwrap`
    /// below). `end < start` makes the subtraction underflow.
    unsafe fn update(&mut self, start: usize, end: usize) -> T {
        // Numerator first: the inner window is updated before anything else.
        let total = self.sum.update(start, end);
        let window_len = end - start;
        let divisor = <T as NumCast>::from(window_len).unwrap();
        total / divisor
    }
}

/// Computes a rolling mean over `values` and returns it as an [`ArrayRef`].
///
/// The work is delegated based on `weights` and `center`:
///
/// * without weights, the `MeanWindow` aggregator is driven by
///   `rolling_apply_agg_window`;
/// * with weights, they are first passed through `no_nulls::coerce_weights`
///   and the result is produced by `no_nulls::rolling_apply_weights` using
///   `no_nulls::compute_mean_weights`.
///
/// In either case `center == true` selects `det_offsets_center` to determine
/// the window bounds, and `center == false` selects `det_offsets`.
///
/// `window_size` and `min_periods` are forwarded untouched to the chosen
/// helper; their exact semantics are defined there.
/// NOTE(review): presumably `min_periods` is the minimum number of values
/// needed before a result is emitted -- confirm in `no_nulls`.
pub fn rolling_mean<T>(
    values: &[T],
    window_size: usize,
    min_periods: usize,
    center: bool,
    weights: Option<&[f64]>,
) -> ArrayRef
where
    T: NativeType + Float + std::iter::Sum<T> + SubAssign + AddAssign + IsFloat,
{
    // The offset functions are distinct function items, so each combination
    // keeps its own call rather than selecting a function value at runtime.
    match weights {
        None if center => rolling_apply_agg_window::<MeanWindow<_>, _, _>(
            values,
            window_size,
            min_periods,
            det_offsets_center,
        ),
        None => rolling_apply_agg_window::<MeanWindow<_>, _, _>(
            values,
            window_size,
            min_periods,
            det_offsets,
        ),
        Some(raw_weights) => {
            let coerced = no_nulls::coerce_weights(raw_weights);
            if center {
                no_nulls::rolling_apply_weights(
                    values,
                    window_size,
                    min_periods,
                    det_offsets_center,
                    no_nulls::compute_mean_weights,
                    &coerced,
                )
            } else {
                no_nulls::rolling_apply_weights(
                    values,
                    window_size,
                    min_periods,
                    det_offsets,
                    no_nulls::compute_mean_weights,
                    &coerced,
                )
            }
        }
    }
}
47 changes: 43 additions & 4 deletions polars/polars-arrow/src/kernels/rolling/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
mod mean_no_nulls;
pub mod no_nulls;
pub mod nulls;
pub mod quantile_no_nulls;
pub mod quantile_nulls;
mod quantile_no_nulls;
mod quantile_nulls;
mod sum_min_max_no_nulls;
mod window;

use crate::data_types::IsFloat;
Expand All @@ -14,8 +16,7 @@ use arrow::types::NativeType;
use num::ToPrimitive;
use num::{Bounded, Float, NumCast, One, Zero};
use std::cmp::Ordering;
use std::ops::AddAssign;
use std::ops::{Add, Div, Mul, Sub};
use std::ops::{Add, AddAssign, Div, Mul, Sub, SubAssign};
use std::sync::Arc;
use window::*;

Expand All @@ -25,6 +26,44 @@ type Idx = usize;
type WindowSize = usize;
type Len = usize;

/// Orders `a` against `b`, placing NaN *below* every non-NaN value.
///
/// * Types for which `T::is_float()` is false are compared with
///   `partial_cmp` directly.
/// * For float types, two NaNs compare `Equal`, a NaN on the left is `Less`,
///   a NaN on the right is `Greater`, and two non-NaN values use `partial_cmp`.
fn compare_fn_nan_min<T>(a: &T, b: &T) -> Ordering
where
    T: PartialOrd + IsFloat + NativeType,
{
    if !T::is_float() {
        // SAFETY: relies on non-float native types always being comparable
        // (the original author's note: all integers are `Ord`), so
        // `partial_cmp` is expected to never return `None` here.
        return unsafe { a.partial_cmp(b).unwrap_unchecked() };
    }

    let lhs_nan = a.is_nan();
    let rhs_nan = b.is_nan();

    if lhs_nan && rhs_nan {
        Ordering::Equal
    } else if lhs_nan {
        Ordering::Less
    } else if rhs_nan {
        Ordering::Greater
    } else {
        // SAFETY: both operands were just checked to be non-NaN, the case the
        // original code relies on for `partial_cmp` to yield `Some`.
        unsafe { a.partial_cmp(b).unwrap_unchecked() }
    }
}

/// Orders `a` against `b`, placing NaN *above* every non-NaN value.
///
/// * Types for which `T::is_float()` is false are compared with
///   `partial_cmp` directly.
/// * For float types, two NaNs compare `Equal`, a NaN on the left is
///   `Greater`, a NaN on the right is `Less`, and two non-NaN values use
///   `partial_cmp`.
fn compare_fn_nan_max<T>(a: &T, b: &T) -> Ordering
where
    T: PartialOrd + IsFloat + NativeType,
{
    if !T::is_float() {
        // SAFETY: relies on non-float native types always being comparable
        // (the original author's note: all integers are `Ord`), so
        // `partial_cmp` is expected to never return `None` here.
        return unsafe { a.partial_cmp(b).unwrap_unchecked() };
    }

    let lhs_nan = a.is_nan();
    let rhs_nan = b.is_nan();

    if lhs_nan && rhs_nan {
        Ordering::Equal
    } else if lhs_nan {
        Ordering::Greater
    } else if rhs_nan {
        Ordering::Less
    } else {
        // SAFETY: both operands were just checked to be non-NaN, the case the
        // original code relies on for `partial_cmp` to yield `Some`.
        unsafe { a.partial_cmp(b).unwrap_unchecked() }
    }
}

/// Returns the half-open bounds `(start, end)` of a trailing window that ends
/// at index `i` (inclusive).
///
/// `end` is always `i + 1`; `start` is `i - (window_size - 1)`, clamped to `0`
/// when the window would reach before the first element. `_len` is ignored.
///
/// `window_size` must be at least 1: `window_size - 1` underflows for `0`.
fn det_offsets(i: Idx, window_size: WindowSize, _len: Len) -> (usize, usize) {
    // Number of elements before `i` that belong to the window.
    let lookback = window_size - 1;
    let start = i.saturating_sub(lookback);
    let end = i + 1;
    (start, end)
}
Expand Down

0 comments on commit 24cc2ef

Please sign in to comment.