Skip to content

Commit

Permalink
fix panics with weighted rolling median (#2259)
Browse files Browse the repository at this point in the history
  • Loading branch information
marcvanheerden committed Jan 4, 2022
1 parent 16ab34c commit 4380682
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 47 deletions.
49 changes: 22 additions & 27 deletions polars/polars-arrow/src/kernels/rolling/no_nulls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,8 @@ where
+ Add<Output = T>
+ Sub<Output = T>
+ Div<Output = T>
+ Mul<Output = T>,
+ Mul<Output = T>
+ Zero,
{
match (center, weights) {
(true, None) => rolling_apply_quantile(
Expand All @@ -218,32 +219,26 @@ where
det_offsets,
compute_quantile,
),
(true, Some(weights)) => {
let values = as_floats(values);
rolling_apply_convolve_quantile(
values,
0.5,
QuantileInterpolOptions::Linear,
window_size,
min_periods,
det_offsets_center,
compute_quantile,
weights,
)
}
(false, Some(weights)) => {
let values = as_floats(values);
rolling_apply_convolve_quantile(
values,
0.5,
QuantileInterpolOptions::Linear,
window_size,
min_periods,
det_offsets,
compute_quantile,
weights,
)
}
(true, Some(weights)) => rolling_apply_convolve_quantile(
values,
0.5,
QuantileInterpolOptions::Linear,
window_size,
min_periods,
det_offsets_center,
compute_quantile,
weights,
),
(false, Some(weights)) => rolling_apply_convolve_quantile(
values,
0.5,
QuantileInterpolOptions::Linear,
window_size,
min_periods,
det_offsets,
compute_quantile,
weights,
),
}
}

Expand Down
56 changes: 36 additions & 20 deletions polars/polars-core/src/chunked_array/ops/rolling_window.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,13 +123,6 @@ mod inner_mod {
check_input(options.window_size, options.min_periods)?;
let ca = self.rechunk();

if options.weights.is_some()
&& !matches!(self.dtype(), DataType::Float64 | DataType::Float32)
{
let s = ca.cast(&DataType::Float64).unwrap();
return s.rolling_median(options);
}

let arr = ca.downcast_iter().next().unwrap();
let arr = match self.has_validity() {
false => rolling::no_nulls::rolling_median(
Expand Down Expand Up @@ -170,13 +163,6 @@ mod inner_mod {
check_input(options.window_size, options.min_periods)?;
let ca = self.rechunk();

if options.weights.is_some()
&& !matches!(self.dtype(), DataType::Float64 | DataType::Float32)
{
let s = ca.cast(&DataType::Float64).unwrap();
return s.rolling_quantile(quantile, interpolation, options);
}

let arr = ca.downcast_iter().next().unwrap();
let arr = match self.has_validity() {
false => rolling::no_nulls::rolling_quantile(
Expand Down Expand Up @@ -665,10 +651,19 @@ mod test {
#[test]
fn test_median_quantile_types() {
let ca = Int32Chunked::new("foo", &[1, 2, 3, 2, 1]);
let rolmed = ca
let rol_med = ca
.rolling_median(RollingOptions {
window_size: 2,
min_periods: 1,
..Default::default()
})
.unwrap();

let rol_med_weighted = ca
.rolling_median(RollingOptions {
window_size: 2,
min_periods: 1,
weights: Some(vec![1.0, 2.0]),
..Default::default()
})
.unwrap();
Expand Down Expand Up @@ -698,19 +693,29 @@ mod test {
)
.unwrap();

assert_eq!(*rolmed.dtype(), DataType::Float64);
assert_eq!(*rol_med.dtype(), DataType::Float64);
assert_eq!(*rol_med_weighted.dtype(), DataType::Float64);
assert_eq!(*rol_quantile.dtype(), DataType::Float64);
assert_eq!(*rol_quantile_weighted.dtype(), DataType::Float64);

let ca = Float32Chunked::new("foo", &[1.0, 2.0, 3.0, 2.0, 1.0]);
let rolmed = ca
let rol_med = ca
.rolling_median(RollingOptions {
window_size: 2,
min_periods: 1,
..Default::default()
})
.unwrap();

let rol_med_weighted = ca
.rolling_median(RollingOptions {
window_size: 2,
min_periods: 1,
weights: Some(vec![1.0, 2.0]),
..Default::default()
})
.unwrap();

let rol_quantile = ca
.rolling_quantile(
0.3,
Expand All @@ -736,15 +741,25 @@ mod test {
)
.unwrap();

assert_eq!(*rolmed.dtype(), DataType::Float32);
assert_eq!(*rol_med.dtype(), DataType::Float32);
assert_eq!(*rol_med_weighted.dtype(), DataType::Float32);
assert_eq!(*rol_quantile.dtype(), DataType::Float32);
assert_eq!(*rol_quantile_weighted.dtype(), DataType::Float32);

let ca = Float64Chunked::new("foo", &[1.0, 2.0, 3.0, 2.0, 1.0]);
let rolmed = ca
let rol_med = ca
.rolling_median(RollingOptions {
window_size: 2,
min_periods: 1,
..Default::default()
})
.unwrap();

let rol_med_weighted = ca
.rolling_median(RollingOptions {
window_size: 2,
min_periods: 1,
weights: Some(vec![1.0, 2.0]),
..Default::default()
})
.unwrap();
Expand Down Expand Up @@ -774,7 +789,8 @@ mod test {
)
.unwrap();

assert_eq!(*rolmed.dtype(), DataType::Float64);
assert_eq!(*rol_med.dtype(), DataType::Float64);
assert_eq!(*rol_med_weighted.dtype(), DataType::Float64);
assert_eq!(*rol_quantile.dtype(), DataType::Float64);
assert_eq!(*rol_quantile_weighted.dtype(), DataType::Float64);
}
Expand Down

0 comments on commit 4380682

Please sign in to comment.