Skip to content

Commit

Permalink
fix rolling_mean on integers: closes #1411
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Sep 22, 2021
1 parent 0621721 commit 2c9d59f
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 37 deletions.
36 changes: 20 additions & 16 deletions polars/polars-core/src/chunked_array/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,13 +121,12 @@ pub trait ChunkBytes {
fn to_byte_slices(&self) -> Vec<&[u8]>;
}

/// Rolling window functions
#[cfg(feature = "rolling_window")]
pub trait ChunkWindow {
/// apply a rolling sum (moving sum) over the values in this array.
/// a window of length `window_size` will traverse the array. the values that fill this window
/// will (optionally) be multiplied with the weights given by the `weight` vector. the resulting
/// values will be aggregated to their sum.
pub trait ChunkWindowMean {
/// Apply a rolling mean (moving mean) over the values in this array.
/// A window of length `window_size` will traverse the array. The values that fill this window
/// will (optionally) be multiplied with the weights given by the `weight` vector. The resulting
/// values will be aggregated to their mean.
///
/// # Arguments
///
Expand All @@ -137,24 +136,30 @@ pub trait ChunkWindow {
/// * `ignore_null` - Toggle behavior of aggregation regarding null values in the window.
/// `true` -> Null values will be ignored.
/// `false` -> Any Null in the window leads to a Null in the aggregation result.
fn rolling_sum(
/// * `min_periods` - Amount of elements in the window that should be filled before computing a result.
fn rolling_mean(
&self,
_window_size: u32,
_weight: Option<&[f64]>,
_ignore_null: bool,
_min_periods: u32,
) -> Result<Self>
) -> Result<Series>
where
Self: std::marker::Sized,
{
Err(PolarsError::InvalidOperation(
"rolling sum not supported for this datatype".into(),
"rolling mean not supported for this datatype".into(),
))
}
/// Apply a rolling mean (moving mean) over the values in this array.
/// A window of length `window_size` will traverse the array. The values that fill this window
/// will (optionally) be multiplied with the weights given by the `weight` vector. The resulting
/// values will be aggregated to their mean.
}

/// Rolling window functions
#[cfg(feature = "rolling_window")]
pub trait ChunkWindow {
/// apply a rolling sum (moving sum) over the values in this array.
/// a window of length `window_size` will traverse the array. the values that fill this window
/// will (optionally) be multiplied with the weights given by the `weight` vector. the resulting
/// values will be aggregated to their sum.
///
/// # Arguments
///
Expand All @@ -164,8 +169,7 @@ pub trait ChunkWindow {
/// * `ignore_null` - Toggle behavior of aggregation regarding null values in the window.
/// `true` -> Null values will be ignored.
/// `false` -> Any Null in the window leads to a Null in the aggregation result.
/// * `min_periods` - Amount of elements in the window that should be filled before computing a result.
fn rolling_mean(
fn rolling_sum(
&self,
_window_size: u32,
_weight: Option<&[f64]>,
Expand All @@ -176,7 +180,7 @@ pub trait ChunkWindow {
Self: std::marker::Sized,
{
Err(PolarsError::InvalidOperation(
"rolling mean not supported for this datatype".into(),
"rolling sum not supported for this datatype".into(),
))
}

Expand Down
63 changes: 49 additions & 14 deletions polars/polars-core/src/chunked_array/ops/rolling_window.rs
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,44 @@ where
}
}

impl<T> ChunkWindowMean for ChunkedArray<T>
where
T: PolarsNumericType,
T::Native: Add<Output = T::Native>
+ Sub<Output = T::Native>
+ Mul<Output = T::Native>
+ Div<Output = T::Native>
+ Rem<Output = T::Native>
+ Zero
+ Bounded
+ NumCast
+ PartialOrd
+ One
+ Copy,
ChunkedArray<T>: IntoSeries,
{
fn rolling_mean(
&self,
window_size: u32,
weight: Option<&[f64]>,
ignore_null: bool,
min_periods: u32,
) -> Result<Series> {
match self.dtype() {
DataType::Float32 | DataType::Float64 => {
check_input(window_size, min_periods)?;
let ca = self.rolling_sum(window_size, weight, ignore_null, min_periods)?;
let rolling_window_size = self.window_size(window_size, None, min_periods);
Ok((&ca).div(&rolling_window_size).into_series())
}
_ => {
let ca = self.cast::<Float64Type>()?;
ca.rolling_mean(window_size, weight, ignore_null, min_periods)
}
}
}
}

#[derive(Clone, Copy)]
pub enum InitFold {
Zero,
Expand Down Expand Up @@ -270,19 +308,6 @@ where
))
}

fn rolling_mean(
&self,
window_size: u32,
weight: Option<&[f64]>,
ignore_null: bool,
min_periods: u32,
) -> Result<Self> {
check_input(window_size, min_periods)?;
let rolling_window_size = self.window_size(window_size, None, min_periods);
let ca = self.rolling_sum(window_size, weight, ignore_null, min_periods)?;
Ok((&ca).div(&rolling_window_size))
}

fn rolling_min(
&self,
window_size: u32,
Expand Down Expand Up @@ -646,8 +671,9 @@ mod test {

// validate that we divide by the proper window length. (same as pandas)
let a = ca.rolling_mean(3, None, true, 1).unwrap();
let a = a.f64().unwrap();
assert_eq!(
Vec::from(&a),
Vec::from(a),
&[
Some(0.0),
Some(0.5),
Expand All @@ -658,6 +684,15 @@ mod test {
Some(5.5)
]
);

// integers
let ca = Int32Chunked::new_from_slice("", &[1, 8, 6, 2, 16, 10]);
let out = ca.rolling_mean(2, None, true, 2).unwrap();
let out = out.f64().unwrap();
assert_eq!(
Vec::from(out),
&[None, Some(4.5), Some(7.0), Some(4.0), Some(9.0), Some(13.0),]
);
}

#[test]
Expand Down
10 changes: 8 additions & 2 deletions polars/polars-core/src/series/implementations/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,14 @@ macro_rules! impl_dyn_series {
ignore_null: bool,
min_periods: u32,
) -> Result<Series> {
ChunkWindow::rolling_mean(&self.0, window_size, weight, ignore_null, min_periods)
.map(|ca| ca.into_series())
ChunkWindowMean::rolling_mean(
&self.0,
window_size,
weight,
ignore_null,
min_periods,
)
.map(|ca| ca.into_series())
}
#[cfg(feature = "rolling_window")]
fn _rolling_sum(
Expand Down
10 changes: 5 additions & 5 deletions py-polars/tests/test_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,16 +416,16 @@ def test_rolling(fruits_cars):
df = fruits_cars
assert df.select(
[
col("A").rolling_min(3, min_periods=1).alias("1"),
col("A").rolling_mean(3, min_periods=1).alias("2"),
col("A").rolling_max(3, min_periods=1).alias("3"),
col("A").rolling_sum(3, min_periods=1).alias("4"),
pl.col("A").rolling_min(3, min_periods=1).alias("1"),
pl.col("A").rolling_mean(3, min_periods=1).alias("2"),
pl.col("A").rolling_max(3, min_periods=1).alias("3"),
pl.col("A").rolling_sum(3, min_periods=1).alias("4"),
]
).frame_equal(
pl.DataFrame(
{
"1": [1, 1, 1, 2, 3],
"2": [1, 1, 2, 3, 4],
"2": [1.0, 1.5, 2.0, 3.0, 4.0],
"3": [1, 2, 3, 4, 5],
"5": [1, 3, 6, 9, 12],
}
Expand Down

0 comments on commit 2c9d59f

Please sign in to comment.