Skip to content

Commit

Permalink
fix[rust]: fix and test all rolling extrema paths (#4453)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Aug 17, 2022
1 parent 401b35e commit 221a468
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 3 deletions.
6 changes: 3 additions & 3 deletions polars/polars-arrow/src/kernels/rolling/no_nulls/min_max.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ impl<'a, T: NativeType + IsFloat + PartialOrd> RollingAggWindowNoNulls<'a, T> fo
.iter()
.min_by(|a, b| compare_fn_nan_min(*a, *b))
.unwrap_or(
&self.slice[std::cmp::max(self.last_start, self.last_end.saturating_sub(1))],
&self.slice[std::cmp::min(self.last_start, self.last_end.saturating_sub(1))],
);

if recompute_min {
Expand Down Expand Up @@ -425,12 +425,12 @@ where
}
(false, None) => {
// will be O(n2)
if is_reverse_sorted_max(values) {
if is_sorted_min(values) {
rolling_apply_agg_window::<SortedMinMax<_>, _, _>(
values,
window_size,
min_periods,
det_offsets_center,
det_offsets,
)
} else {
rolling_apply_agg_window::<MinWindow<_>, _, _>(
Expand Down
49 changes: 49 additions & 0 deletions py-polars/tests/test_rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,52 @@ def test_rolling_skew() -> None:
0.16923763134384154,
]
)


def test_rolling_extrema() -> None:
# sorted data and nulls flags trigger different kernels
df = (
pl.DataFrame(
{
"col1": pl.arange(0, 7, eager=True),
"col2": pl.arange(0, 7, eager=True).reverse(),
}
)
).with_columns(
[
pl.when(pl.arange(0, pl.count(), eager=False) < 2)
.then(None)
.otherwise(pl.all())
.suffix("_nulls")
]
)

assert df.select([pl.all().rolling_min(3)]).to_dict(False) == {
"col1": [None, None, 0, 1, 2, 3, 4],
"col2": [None, None, 4, 3, 2, 1, 0],
"col1_nulls": [None, None, None, None, 2, 3, 4],
"col2_nulls": [None, None, None, None, 2, 1, 0],
}

assert df.select([pl.all().rolling_max(3)]).to_dict(False) == {
"col1": [None, None, 2, 3, 4, 5, 6],
"col2": [None, None, 6, 5, 4, 3, 2],
"col1_nulls": [None, None, None, None, 4, 5, 6],
"col2_nulls": [None, None, None, None, 4, 3, 2],
}

# shuffled data triggers other kernels
df = df.select([pl.all().shuffle(0)])
assert df.select([pl.all().rolling_min(3)]).to_dict(False) == {
"col1": [None, None, 0, 0, 1, 2, 2],
"col2": [None, None, 0, 2, 1, 1, 1],
"col1_nulls": [None, None, None, None, None, 2, 2],
"col2_nulls": [None, None, None, None, None, 1, 1],
}

assert df.select([pl.all().rolling_max(3)]).to_dict(False) == {
"col1": [None, None, 6, 4, 5, 5, 5],
"col2": [None, None, 6, 6, 5, 4, 4],
"col1_nulls": [None, None, None, None, None, 5, 5],
"col2_nulls": [None, None, None, None, None, 4, 4],
}

0 comments on commit 221a468

Please sign in to comment.