Skip to content

Commit

Permalink
improve rolling_window performance
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Sep 18, 2021
1 parent fe00bd7 commit ef911b4
Showing 1 changed file with 23 additions and 25 deletions.
48 changes: 23 additions & 25 deletions polars/polars-core/src/chunked_array/ops/rolling_window.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::prelude::*;
use arrow::array::{Array, PrimitiveArray};
use num::{Bounded, NumCast, One, Zero};
use polars_arrow::utils::CustomIterTools;
use std::ops::{Add, Div, Mul, Rem, Sub};

/// a fold function to compute the sum. Returns a Null if there is a single null in the window
Expand Down Expand Up @@ -114,13 +115,11 @@ where
fn update_state<T>(
// (window , oldest_value_idx, amount_Some)
state: &mut (Vec<Option<T>>, u32, u32),
// count of the loop
idx_count: u32,
// new value
opt_v: Option<T>,
// size of the window
window_size: u32,
) -> u32 {
) {
let (window, idx, _) = state;
let old_value = &mut window[*idx as usize];
let mut new_val = opt_v;
Expand All @@ -134,28 +133,19 @@ fn update_state<T>(

std::mem::swap(old_value, &mut new_val);

let idx_count = idx_count + 1;
state.1 = idx_count % window_size;
idx_count
// this removes an expensive modulo
state.1 += 1;
if state.1 == window_size {
state.1 = 0
}
}

/// Apply weight to the current window and accumulate with a `fold_fn`.
fn apply_window<T, F>(
weight: Option<&[T]>,
window: &[Option<T>],
fold_fn: F,
init_fold: InitFold,
) -> Option<T>
fn apply_window<T, F>(weight: Option<&[T]>, window: &[Option<T>], fold_fn: F, init: T) -> Option<T>
where
T: Copy + Add<Output = T> + Zero + Mul<Output = T> + Bounded,
F: Fn(Option<T>, Option<T>) -> Option<T>,
{
let init = match init_fold {
InitFold::Zero => Zero::zero(),
InitFold::Min => Bounded::min_value(),
InitFold::Max => Bounded::max_value(),
};

match weight {
None => window.iter().copied().fold(Some(init), fold_fn),
Some(weight) => rescale_window(window, weight)
Expand Down Expand Up @@ -193,32 +183,40 @@ where
{
let weight: Option<Vec<T::Native>> = weight.map(weight_to_native);
let window = vec![None; window_size as usize];
let mut idx_count = 0;

let init = match init_fold {
InitFold::Zero => Zero::zero(),
InitFold::Min => Bounded::min_value(),
InitFold::Max => Bounded::max_value(),
};

if ca.null_count() == 0 {
ca.into_no_null_iter()
.scan((window, 0u32, 0u32), |state, v| {
idx_count = update_state(state, idx_count, Some(v), window_size);
update_state(state, Some(v), window_size);
let (window, _, some_count) = state;
if *some_count < min_periods {
Some(None)
} else {
let sum = apply_window(weight.as_deref(), window, fold_fn, init_fold);
let sum = apply_window(weight.as_deref(), window, fold_fn, init);
Some(sum)
}
})
.collect()
.trust_my_length(ca.len())
.collect_trusted()
} else {
ca.into_iter()
.scan((window, 0u32, 0u32), |state, opt_v| {
idx_count = update_state(state, idx_count, opt_v, window_size);
update_state(state, opt_v, window_size);
let (window, _, some_count) = state;
if *some_count < min_periods {
Some(None)
} else {
Some(apply_window(weight.as_deref(), window, fold_fn, init_fold))
Some(apply_window(weight.as_deref(), window, fold_fn, init))
}
})
.collect()
.trust_my_length(ca.len())
.collect_trusted()
}
}

Expand Down

0 comments on commit ef911b4

Please sign in to comment.