Skip to content

Commit

Permalink
fix weighted rolling aggregates except quantile
Browse files Browse the repository at this point in the history
  • Loading branch information
marcvanheerden authored and ritchie46 committed Feb 2, 2022
1 parent 621b5c7 commit c26b638
Showing 1 changed file with 108 additions and 20 deletions.
128 changes: 108 additions & 20 deletions polars/polars-arrow/src/kernels/rolling/no_nulls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,37 @@ where
))
}

fn rolling_apply_weights<Fo, Fa>(
values: &[f64],
window_size: usize,
min_periods: usize,
det_offsets_fn: Fo,
aggregator: Fa,
weights: &[f64],
) -> ArrayRef
where
Fo: Fn(Idx, WindowSize, Len) -> (Start, End),
Fa: Fn(&[f64], &[f64]) -> f64,
{
assert_eq!(weights.len(), window_size);
let len = values.len();
let out = (0..len)
.map(|idx| {
let (start, end) = det_offsets_fn(idx, window_size, len);
let vals = unsafe { values.get_unchecked(start..end) };

aggregator(vals, weights)
})
.collect_trusted::<Vec<f64>>();

let validity = create_validity(min_periods, len as usize, window_size, det_offsets_fn);
Arc::new(PrimitiveArray::from_data(
DataType::Float64,
out.into(),
validity.map(|b| b.into()),
))
}

fn rolling_apply<T, K, Fo, Fa>(
values: &[T],
window_size: usize,
Expand Down Expand Up @@ -174,13 +205,39 @@ where
sum / (len - T::one())
}

fn compute_var_weights<T>(vals: &[T], weights: &[T]) -> T
where
T: Float + std::ops::AddAssign,
{
let weighted_iter = vals.iter().zip(weights).map(|(x, y)| *x * *y);

let mut x = T::zero();
let mut xsquare = T::zero();
let mut len = T::zero();

for val in weighted_iter {
x += val;
xsquare += val * val;
len += T::one();
}

((xsquare / len) - (x / len) * (x / len)) / (len - T::one()) * len
}

pub(crate) fn compute_mean<T>(values: &[T]) -> T
where
T: Float + std::iter::Sum<T>,
{
values.iter().copied().sum::<T>() / T::from(values.len()).unwrap()
}

pub(crate) fn compute_mean_weights<T>(values: &[T], weights: &[T]) -> T
where
T: Float + std::iter::Sum<T>,
{
values.iter().zip(weights).map(|(v, w)| *v * *w).sum::<T>() / T::from(values.len()).unwrap()
}

pub fn rolling_quantile<T>(
values: &[T],
quantile: f64,
Expand Down Expand Up @@ -251,6 +308,13 @@ where
values.iter().copied().sum()
}

pub(crate) fn compute_sum_weights<T>(values: &[T], weights: &[T]) -> T
where
T: std::iter::Sum<T> + Copy + std::ops::Mul<Output = T>,
{
values.iter().zip(weights).map(|(v, w)| *v * *w).sum()
}

pub(crate) fn compute_quantile<T>(
values: &[T],
quantile: f64,
Expand Down Expand Up @@ -325,6 +389,18 @@ where
.unwrap()
}

pub(crate) fn compute_min_weights<T>(values: &[T], weights: &[T]) -> T
where
T: NativeType + PartialOrd + std::ops::Mul<Output = T>,
{
values
.iter()
.zip(weights)
.map(|(v, w)| *v * *w)
.min_by(|a, b| a.partial_cmp(b).unwrap())
.unwrap()
}

pub(crate) fn compute_max<T>(values: &[T]) -> T
where
T: NativeType + PartialOrd,
Expand All @@ -336,6 +412,18 @@ where
.unwrap()
}

pub(crate) fn compute_max_weights<T>(values: &[T], weights: &[T]) -> T
where
T: NativeType + PartialOrd + std::ops::Mul<Output = T>,
{
values
.iter()
.zip(weights)
.map(|(v, w)| *v * *w)
.max_by(|a, b| a.partial_cmp(b).unwrap())
.unwrap()
}

fn as_floats<T>(values: &[T]) -> &[f64]
where
T: Any,
Expand Down Expand Up @@ -370,23 +458,23 @@ where
(false, None) => rolling_apply(values, window_size, min_periods, det_offsets, compute_mean),
(true, Some(weights)) => {
let values = as_floats(values);
rolling_apply_convolve(
rolling_apply_weights(
values,
window_size,
min_periods,
det_offsets_center,
compute_mean,
compute_mean_weights,
weights,
)
}
(false, Some(weights)) => {
let values = as_floats(values);
rolling_apply_convolve(
rolling_apply_weights(
values,
window_size,
min_periods,
det_offsets,
compute_mean,
compute_mean_weights,
weights,
)
}
Expand Down Expand Up @@ -414,23 +502,23 @@ where
(false, None) => rolling_apply(values, window_size, min_periods, det_offsets, compute_min),
(true, Some(weights)) => {
let values = as_floats(values);
rolling_apply_convolve(
rolling_apply_weights(
values,
window_size,
min_periods,
det_offsets_center,
compute_min,
compute_min_weights,
weights,
)
}
(false, Some(weights)) => {
let values = as_floats(values);
rolling_apply_convolve(
rolling_apply_weights(
values,
window_size,
min_periods,
det_offsets,
compute_min,
compute_min_weights,
weights,
)
}
Expand Down Expand Up @@ -458,23 +546,23 @@ where
(false, None) => rolling_apply(values, window_size, min_periods, det_offsets, compute_max),
(true, Some(weights)) => {
let values = as_floats(values);
rolling_apply_convolve(
rolling_apply_weights(
values,
window_size,
min_periods,
det_offsets_center,
compute_max,
compute_max_weights,
weights,
)
}
(false, Some(weights)) => {
let values = as_floats(values);
rolling_apply_convolve(
rolling_apply_weights(
values,
window_size,
min_periods,
det_offsets,
compute_max,
compute_max_weights,
weights,
)
}
Expand Down Expand Up @@ -502,23 +590,23 @@ where
(false, None) => rolling_apply(values, window_size, min_periods, det_offsets, compute_var),
(true, Some(weights)) => {
let values = as_floats(values);
rolling_apply_convolve(
rolling_apply_weights(
values,
window_size,
min_periods,
det_offsets_center,
compute_var,
compute_var_weights,
weights,
)
}
(false, Some(weights)) => {
let values = as_floats(values);
rolling_apply_convolve(
rolling_apply_weights(
values,
window_size,
min_periods,
det_offsets,
compute_var,
compute_var_weights,
weights,
)
}
Expand Down Expand Up @@ -546,23 +634,23 @@ where
(false, None) => rolling_apply(values, window_size, min_periods, det_offsets, compute_sum),
(true, Some(weights)) => {
let values = as_floats(values);
rolling_apply_convolve(
rolling_apply_weights(
values,
window_size,
min_periods,
det_offsets_center,
compute_sum,
compute_sum_weights,
weights,
)
}
(false, Some(weights)) => {
let values = as_floats(values);
rolling_apply_convolve(
rolling_apply_weights(
values,
window_size,
min_periods,
det_offsets,
compute_sum,
compute_sum_weights,
weights,
)
}
Expand Down

0 comments on commit c26b638

Please sign in to comment.