Skip to content

Commit

Permalink
add rolling options for rolling median and quantile (#2229)
Browse files Browse the repository at this point in the history
  • Loading branch information
marcvanheerden committed Jan 3, 2022
1 parent f6545db commit 16ab34c
Show file tree
Hide file tree
Showing 33 changed files with 1,213 additions and 55 deletions.
357 changes: 356 additions & 1 deletion polars/polars-arrow/src/kernels/rolling/no_nulls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,27 @@ use crate::utils::CustomIterTools;
use arrow::array::{ArrayRef, PrimitiveArray};
use arrow::datatypes::DataType;
use arrow::types::NativeType;
use num::Float;
use num::{Float, NumCast, ToPrimitive, Zero};
use std::any::Any;
use std::fmt::Debug;
use std::ops::{Add, Div, Mul, Sub};
use std::sync::Arc;

#[derive(Clone, Copy, PartialEq, Debug)]
pub enum QuantileInterpolOptions {
Nearest,
Lower,
Higher,
Midpoint,
Linear,
}

impl Default for QuantileInterpolOptions {
fn default() -> Self {
QuantileInterpolOptions::Nearest
}
}

fn rolling_apply_convolve<Fo, Fa>(
values: &[f64],
window_size: usize,
Expand Down Expand Up @@ -73,6 +89,76 @@ where
))
}

fn rolling_apply_quantile<T, Fo, Fa>(
values: &[T],
quantile: f64,
interpolation: QuantileInterpolOptions,
window_size: usize,
min_periods: usize,
det_offsets_fn: Fo,
aggregator: Fa,
) -> ArrayRef
where
Fo: Fn(Idx, WindowSize, Len) -> (Start, End),
Fa: Fn(&[T], f64, QuantileInterpolOptions) -> T,
T: Debug + NativeType,
{
let len = values.len();
let out = (0..len)
.map(|idx| {
let (start, end) = det_offsets_fn(idx, window_size, len);
let vals = unsafe { values.get_unchecked(start..end) };
aggregator(vals, quantile, interpolation)
})
.collect_trusted::<Vec<T>>();

let validity = create_validity(min_periods, len as usize, window_size, det_offsets_fn);
Arc::new(PrimitiveArray::from_data(
T::PRIMITIVE.into(),
out.into(),
validity.map(|b| b.into()),
))
}

#[allow(clippy::too_many_arguments)]
fn rolling_apply_convolve_quantile<T, Fo, Fa>(
values: &[T],
quantile: f64,
interpolation: QuantileInterpolOptions,
window_size: usize,
min_periods: usize,
det_offsets_fn: Fo,
aggregator: Fa,
weights: &[f64],
) -> ArrayRef
where
Fo: Fn(Idx, WindowSize, Len) -> (Start, End),
Fa: Fn(&[T], f64, QuantileInterpolOptions) -> T,
T: Debug + NativeType + Mul<Output = T> + NumCast + ToPrimitive + Zero,
{
assert_eq!(weights.len(), window_size);
let mut buf = vec![T::zero(); window_size];
let len = values.len();
let out = (0..len)
.map(|idx| {
let (start, end) = det_offsets_fn(idx, window_size, len);
let vals = unsafe { values.get_unchecked(start..end) };
buf.iter_mut()
.zip(vals.iter().zip(weights))
.for_each(|(b, (v, w))| *b = *v * NumCast::from(*w).unwrap());

aggregator(&buf, quantile, interpolation)
})
.collect_trusted::<Vec<T>>();

let validity = create_validity(min_periods, len as usize, window_size, det_offsets_fn);
Arc::new(PrimitiveArray::from_data(
T::PRIMITIVE.into(),
out.into(),
validity.map(|b| b.into()),
))
}

pub(crate) fn compute_var<T>(vals: &[T]) -> T
where
T: Float + std::iter::Sum,
Expand All @@ -95,13 +181,205 @@ where
values.iter().copied().sum::<T>() / T::from(values.len()).unwrap()
}

pub fn rolling_median<T>(
values: &[T],
window_size: usize,
min_periods: usize,
center: bool,
weights: Option<&[f64]>,
) -> ArrayRef
where
T: NativeType
+ std::iter::Sum<T>
+ std::cmp::PartialOrd
+ num::ToPrimitive
+ NumCast
+ Add<Output = T>
+ Sub<Output = T>
+ Div<Output = T>
+ Mul<Output = T>,
{
match (center, weights) {
(true, None) => rolling_apply_quantile(
values,
0.5,
QuantileInterpolOptions::Linear,
window_size,
min_periods,
det_offsets_center,
compute_quantile,
),
(false, None) => rolling_apply_quantile(
values,
0.5,
QuantileInterpolOptions::Linear,
window_size,
min_periods,
det_offsets,
compute_quantile,
),
(true, Some(weights)) => {
let values = as_floats(values);
rolling_apply_convolve_quantile(
values,
0.5,
QuantileInterpolOptions::Linear,
window_size,
min_periods,
det_offsets_center,
compute_quantile,
weights,
)
}
(false, Some(weights)) => {
let values = as_floats(values);
rolling_apply_convolve_quantile(
values,
0.5,
QuantileInterpolOptions::Linear,
window_size,
min_periods,
det_offsets,
compute_quantile,
weights,
)
}
}
}

pub fn rolling_quantile<T>(
values: &[T],
quantile: f64,
interpolation: QuantileInterpolOptions,
window_size: usize,
min_periods: usize,
center: bool,
weights: Option<&[f64]>,
) -> ArrayRef
where
T: NativeType
+ std::iter::Sum<T>
+ std::cmp::PartialOrd
+ num::ToPrimitive
+ NumCast
+ Add<Output = T>
+ Sub<Output = T>
+ Div<Output = T>
+ Mul<Output = T>
+ Zero,
{
match (center, weights) {
(true, None) => rolling_apply_quantile(
values,
quantile,
interpolation,
window_size,
min_periods,
det_offsets_center,
compute_quantile,
),
(false, None) => rolling_apply_quantile(
values,
quantile,
interpolation,
window_size,
min_periods,
det_offsets,
compute_quantile,
),
(true, Some(weights)) => rolling_apply_convolve_quantile(
values,
quantile,
interpolation,
window_size,
min_periods,
det_offsets_center,
compute_quantile,
weights,
),
(false, Some(weights)) => rolling_apply_convolve_quantile(
values,
quantile,
interpolation,
window_size,
min_periods,
det_offsets,
compute_quantile,
weights,
),
}
}

pub(crate) fn compute_sum<T>(values: &[T]) -> T
where
T: std::iter::Sum<T> + Copy,
{
values.iter().copied().sum()
}

pub(crate) fn compute_quantile<T>(
values: &[T],
quantile: f64,
interpolation: QuantileInterpolOptions,
) -> T
where
T: std::iter::Sum<T>
+ Copy
+ std::cmp::PartialOrd
+ num::ToPrimitive
+ NumCast
+ Add<Output = T>
+ Sub<Output = T>
+ Div<Output = T>
+ Mul<Output = T>,
{
if !(0.0..=1.0).contains(&quantile) {
panic!("quantile should be between 0.0 and 1.0");
}

let mut vals: Vec<T> = values
.iter()
.copied()
.map(|x| NumCast::from(x).unwrap())
.collect();
vals.sort_by(|a, b| a.partial_cmp(b).unwrap());

let length = vals.len();

let mut idx = match interpolation {
QuantileInterpolOptions::Nearest => ((length as f64) * quantile) as usize,
QuantileInterpolOptions::Lower
| QuantileInterpolOptions::Midpoint
| QuantileInterpolOptions::Linear => ((length as f64 - 1.0) * quantile).floor() as usize,
QuantileInterpolOptions::Higher => ((length as f64 - 1.0) * quantile).ceil() as usize,
};

idx = std::cmp::min(idx, length - 1);

match interpolation {
QuantileInterpolOptions::Midpoint => {
let top_idx = ((length as f64 - 1.0) * quantile).ceil() as usize;
if top_idx == idx {
vals[idx]
} else {
(vals[idx] + vals[idx + 1]) / T::from::<f64>(2.0f64).unwrap()
}
}
QuantileInterpolOptions::Linear => {
let float_idx = (length as f64 - 1.0) * quantile;
let top_idx = f64::ceil(float_idx) as usize;

if top_idx == idx {
vals[idx]
} else {
let proportion = T::from(float_idx - idx as f64).unwrap();
proportion * (vals[top_idx] - vals[idx]) + vals[idx]
}
}
_ => vals[idx],
}
}

pub(crate) fn compute_min<T>(values: &[T]) -> T
where
T: NativeType + PartialOrd,
Expand Down Expand Up @@ -390,4 +668,81 @@ mod test {
let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();
assert_eq!(out, &[None, None, Some(10.0), None]);
}

#[test]
fn test_rolling_median() {
let values = &[1.0, 2.0, 3.0, 4.0];

let out = rolling_median(values, 2, 2, false, None);
let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();
assert_eq!(out, &[None, Some(1.5), Some(2.5), Some(3.5)]);

let out = rolling_median(values, 2, 1, false, None);
let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();
assert_eq!(out, &[Some(1.0), Some(1.5), Some(2.5), Some(3.5)]);

let out = rolling_median(values, 4, 1, false, None);
let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();
assert_eq!(out, &[Some(1.0), Some(1.5), Some(2.0), Some(2.5)]);

let out = rolling_median(values, 4, 1, true, None);
let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();
assert_eq!(out, &[Some(1.5), Some(2.0), Some(2.5), Some(3.0)]);

let out = rolling_median(values, 4, 4, true, None);
let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();
assert_eq!(out, &[None, None, Some(2.5), None]);
}

#[test]
fn test_rolling_quantile_limits() {
let values = &[1.0, 2.0, 3.0, 4.0];

let interpol_options = vec![
QuantileInterpolOptions::Lower,
QuantileInterpolOptions::Higher,
QuantileInterpolOptions::Nearest,
QuantileInterpolOptions::Midpoint,
QuantileInterpolOptions::Linear,
];

for interpol in interpol_options {
let out1 = rolling_min(values, 2, 2, false, None);
let out1 = out1.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
let out1 = out1.into_iter().map(|v| v.copied()).collect::<Vec<_>>();
let out2 = rolling_quantile(values, 0.0, interpol, 2, 2, false, None);
let out2 = out2.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
let out2 = out2.into_iter().map(|v| v.copied()).collect::<Vec<_>>();
assert_eq!(out1, out2);

let out1 = rolling_max(values, 2, 2, false, None);
let out1 = out1.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
let out1 = out1.into_iter().map(|v| v.copied()).collect::<Vec<_>>();
let out2 = rolling_quantile(values, 1.0, interpol, 2, 2, false, None);
let out2 = out2.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
let out2 = out2.into_iter().map(|v| v.copied()).collect::<Vec<_>>();
assert_eq!(out1, out2);
}

let out1 = rolling_median(values, 2, 2, false, None);
let out1 = out1.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
let out1 = out1.into_iter().map(|v| v.copied()).collect::<Vec<_>>();
let out2 = rolling_quantile(
values,
0.5,
QuantileInterpolOptions::Linear,
2,
2,
false,
None,
);
let out2 = out2.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
let out2 = out2.into_iter().map(|v| v.copied()).collect::<Vec<_>>();
assert_eq!(out1, out2);
}
}

0 comments on commit 16ab34c

Please sign in to comment.