Skip to content

Commit

Permalink
add rolling options to rolling_apply (#2283)
Browse files Browse the repository at this point in the history
  • Loading branch information
marcvanheerden committed Jan 12, 2022
1 parent ddc0d3f commit 3de5edf
Show file tree
Hide file tree
Showing 12 changed files with 334 additions and 207 deletions.
127 changes: 45 additions & 82 deletions polars/polars-arrow/src/kernels/rolling/no_nulls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -181,67 +181,6 @@ where
values.iter().copied().sum::<T>() / T::from(values.len()).unwrap()
}

pub fn rolling_median<T>(
values: &[T],
window_size: usize,
min_periods: usize,
center: bool,
weights: Option<&[f64]>,
) -> ArrayRef
where
T: NativeType
+ std::iter::Sum<T>
+ std::cmp::PartialOrd
+ num::ToPrimitive
+ NumCast
+ Add<Output = T>
+ Sub<Output = T>
+ Div<Output = T>
+ Mul<Output = T>
+ Zero,
{
match (center, weights) {
(true, None) => rolling_apply_quantile(
values,
0.5,
QuantileInterpolOptions::Linear,
window_size,
min_periods,
det_offsets_center,
compute_quantile,
),
(false, None) => rolling_apply_quantile(
values,
0.5,
QuantileInterpolOptions::Linear,
window_size,
min_periods,
det_offsets,
compute_quantile,
),
(true, Some(weights)) => rolling_apply_convolve_quantile(
values,
0.5,
QuantileInterpolOptions::Linear,
window_size,
min_periods,
det_offsets_center,
compute_quantile,
weights,
),
(false, Some(weights)) => rolling_apply_convolve_quantile(
values,
0.5,
QuantileInterpolOptions::Linear,
window_size,
min_periods,
det_offsets,
compute_quantile,
weights,
),
}
}

pub fn rolling_quantile<T>(
values: &[T],
quantile: f64,
Expand Down Expand Up @@ -668,27 +607,67 @@ mod test {
fn test_rolling_median() {
let values = &[1.0, 2.0, 3.0, 4.0];

let out = rolling_median(values, 2, 2, false, None);
let out = rolling_quantile(
values,
0.5,
QuantileInterpolOptions::Linear,
2,
2,
false,
None,
);
let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();
assert_eq!(out, &[None, Some(1.5), Some(2.5), Some(3.5)]);

let out = rolling_median(values, 2, 1, false, None);
let out = rolling_quantile(
values,
0.5,
QuantileInterpolOptions::Linear,
2,
1,
false,
None,
);
let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();
assert_eq!(out, &[Some(1.0), Some(1.5), Some(2.5), Some(3.5)]);

let out = rolling_median(values, 4, 1, false, None);
let out = rolling_quantile(
values,
0.5,
QuantileInterpolOptions::Linear,
4,
1,
false,
None,
);
let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();
assert_eq!(out, &[Some(1.0), Some(1.5), Some(2.0), Some(2.5)]);

let out = rolling_median(values, 4, 1, true, None);
let out = rolling_quantile(
values,
0.5,
QuantileInterpolOptions::Linear,
4,
1,
true,
None,
);
let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();
assert_eq!(out, &[Some(1.5), Some(2.0), Some(2.5), Some(3.0)]);

let out = rolling_median(values, 4, 4, true, None);
let out = rolling_quantile(
values,
0.5,
QuantileInterpolOptions::Linear,
4,
4,
true,
None,
);
let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();
assert_eq!(out, &[None, None, Some(2.5), None]);
Expand Down Expand Up @@ -723,21 +702,5 @@ mod test {
let out2 = out2.into_iter().map(|v| v.copied()).collect::<Vec<_>>();
assert_eq!(out1, out2);
}

let out1 = rolling_median(values, 2, 2, false, None);
let out1 = out1.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
let out1 = out1.into_iter().map(|v| v.copied()).collect::<Vec<_>>();
let out2 = rolling_quantile(
values,
0.5,
QuantileInterpolOptions::Linear,
2,
2,
false,
None,
);
let out2 = out2.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
let out2 = out2.into_iter().map(|v| v.copied()).collect::<Vec<_>>();
assert_eq!(out1, out2);
}
}
76 changes: 5 additions & 71 deletions polars/polars-arrow/src/kernels/rolling/nulls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -428,56 +428,6 @@ where
}
}

pub fn rolling_median<T>(
arr: &PrimitiveArray<T>,
window_size: usize,
min_periods: usize,
center: bool,
weights: Option<&[f64]>,
) -> ArrayRef
where
T: NativeType
+ std::iter::Sum
+ Zero
+ AddAssign
+ Copy
+ std::cmp::PartialOrd
+ num::ToPrimitive
+ NumCast
+ Default
+ Add<Output = T>
+ Sub<Output = T>
+ Div<Output = T>
+ Mul<Output = T>,
{
if weights.is_some() {
panic!("weights not yet supported on array with null values")
}
if center {
rolling_apply_quantile(
arr.values().as_slice(),
arr.validity().as_ref().unwrap(),
0.5,
QuantileInterpolOptions::Linear,
window_size,
min_periods,
det_offsets_center,
compute_quantile,
)
} else {
rolling_apply_quantile(
arr.values().as_slice(),
arr.validity().as_ref().unwrap(),
0.5,
QuantileInterpolOptions::Linear,
window_size,
min_periods,
det_offsets,
compute_quantile,
)
}
}

pub fn rolling_quantile<T>(
arr: &PrimitiveArray<T>,
quantile: f64,
Expand Down Expand Up @@ -682,27 +632,27 @@ mod test {
Some(Bitmap::from(&[true, false, true, true])),
);

let out = rolling_median(arr, 2, 2, false, None);
let out = rolling_quantile(arr, 0.5, QuantileInterpolOptions::Linear, 2, 2, false, None);
let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();
assert_eq!(out, &[None, None, None, Some(3.5)]);

let out = rolling_median(arr, 2, 1, false, None);
let out = rolling_quantile(arr, 0.5, QuantileInterpolOptions::Linear, 2, 1, false, None);
let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();
assert_eq!(out, &[Some(1.0), Some(1.0), Some(3.0), Some(3.5)]);

let out = rolling_median(arr, 4, 1, false, None);
let out = rolling_quantile(arr, 0.5, QuantileInterpolOptions::Linear, 4, 1, false, None);
let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();
assert_eq!(out, &[Some(1.0), Some(1.0), Some(2.0), Some(3.0)]);

let out = rolling_median(arr, 4, 1, true, None);
let out = rolling_quantile(arr, 0.5, QuantileInterpolOptions::Linear, 4, 1, true, None);
let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();
assert_eq!(out, &[Some(1.0), Some(2.0), Some(3.0), Some(3.5)]);

let out = rolling_median(arr, 4, 4, true, None);
let out = rolling_quantile(arr, 0.5, QuantileInterpolOptions::Linear, 4, 4, true, None);
let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();
assert_eq!(out, &[None, None, None, None]);
Expand Down Expand Up @@ -767,21 +717,5 @@ mod test {
let out2 = out2.into_iter().map(|v| v.copied()).collect::<Vec<_>>();
assert_eq!(out1, out2);
}

let out1 = rolling_median(values, 2, 1, false, None);
let out1 = out1.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
let out1 = out1.into_iter().map(|v| v.copied()).collect::<Vec<_>>();
let out2 = rolling_quantile(
values,
0.5,
QuantileInterpolOptions::Linear,
2,
1,
false,
None,
);
let out2 = out2.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
let out2 = out2.into_iter().map(|v| v.copied()).collect::<Vec<_>>();
assert_eq!(out1, out2);
}
}
6 changes: 5 additions & 1 deletion polars/polars-core/src/chunked_array/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,11 @@ pub trait ChunkBytes {
/// This likely is a bit slower than ChunkWindow
#[cfg(feature = "rolling_window")]
pub trait ChunkRollApply {
fn rolling_apply(&self, _window_size: usize, _f: &dyn Fn(&Series) -> Series) -> Result<Self>
fn rolling_apply(
&self,
_f: &dyn Fn(&Series) -> Series,
_options: RollingOptions,
) -> Result<Series>
where
Self: Sized,
{
Expand Down

0 comments on commit 3de5edf

Please sign in to comment.