Skip to content

Commit

Permalink
add min_periods to ewma (#2165)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Dec 25, 2021
1 parent 0702f45 commit 5426765
Show file tree
Hide file tree
Showing 7 changed files with 77 additions and 21 deletions.
10 changes: 9 additions & 1 deletion polars/polars-arrow/src/array/default_arrays.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use arrow::array::{BooleanArray, Utf8Array};
use arrow::array::{BooleanArray, PrimitiveArray, Utf8Array};
use arrow::bitmap::Bitmap;
use arrow::buffer::Buffer;
use arrow::datatypes::DataType;
use arrow::types::NativeType;

pub trait FromData<T> {
fn from_data_default(values: T, validity: Option<Bitmap>) -> Self;
Expand All @@ -13,6 +14,13 @@ impl FromData<Bitmap> for BooleanArray {
}
}

impl<T: NativeType> FromData<Buffer<T>> for PrimitiveArray<T> {
fn from_data_default(values: Buffer<T>, validity: Option<Bitmap>) -> Self {
let dt = T::PRIMITIVE;
PrimitiveArray::from_data(dt.into(), values, validity)
}
}

pub trait FromDataUtf8 {
/// # Safety
/// `values` buffer must contain valid utf8 between every `offset`
Expand Down
6 changes: 6 additions & 0 deletions polars/polars-arrow/src/kernels/ewm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,24 @@ pub use average::*;
pub struct EWMOptions {
pub alpha: f64,
pub adjust: bool,
pub min_periods: usize,
}

impl Default for EWMOptions {
fn default() -> Self {
Self {
alpha: 0.5,
adjust: true,
min_periods: 1,
}
}
}

impl EWMOptions {
pub fn and_min_periods(mut self, min_periods: usize) -> Self {
self.min_periods = min_periods;
self
}
pub fn and_adjust(mut self, adjust: bool) -> Self {
self.adjust = adjust;
self
Expand Down
28 changes: 24 additions & 4 deletions polars/polars-core/src/series/ops/ewm.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,22 @@
use crate::prelude::*;
use arrow::bitmap::MutableBitmap;
use arrow::types::NativeType;
pub use polars_arrow::kernels::ewm::EWMOptions;
use polars_arrow::kernels::ewm::{ewma_inf_hist_no_nulls, ewma_no_nulls};
use polars_arrow::prelude::FromData;
use std::convert::TryFrom;

fn prepare_primitive_array<T: NativeType>(vals: Vec<T>, min_periods: usize) -> PrimitiveArray<T> {
if min_periods > 1 {
let mut validity = MutableBitmap::with_capacity(vals.len());
validity.extend_constant(min_periods, false);
validity.extend_constant(vals.len() - min_periods, true);

PrimitiveArray::from_data_default(vals.into(), Some(validity.into()))
} else {
PrimitiveArray::from_data_default(vals.into(), None)
}
}

impl Series {
pub fn ewm_mean(&self, options: EWMOptions) -> Result<Self> {
Expand All @@ -23,7 +39,8 @@ impl Series {
} else {
ewma_inf_hist_no_nulls(vals.iter().copied(), options.alpha as f32)
};
Ok(Float32Chunked::new_vec(self.name(), out).into_series())
let arr = prepare_primitive_array(out, options.min_periods);
Series::try_from((self.name(), Arc::new(arr) as ArrayRef))
}
_ => {
let iter = ca.into_no_null_iter();
Expand All @@ -32,7 +49,8 @@ impl Series {
} else {
ewma_inf_hist_no_nulls(iter, options.alpha as f32)
};
Ok(Float32Chunked::new_vec(self.name(), out).into_series())
let arr = prepare_primitive_array(out, options.min_periods);
Series::try_from((self.name(), Arc::new(arr) as ArrayRef))
}
}
}
Expand All @@ -47,7 +65,8 @@ impl Series {
} else {
ewma_inf_hist_no_nulls(vals.iter().copied(), options.alpha)
};
Ok(Float64Chunked::new_vec(self.name(), out).into_series())
let arr = prepare_primitive_array(out, options.min_periods);
Series::try_from((self.name(), Arc::new(arr) as ArrayRef))
}
_ => {
let iter = ca.into_no_null_iter();
Expand All @@ -56,7 +75,8 @@ impl Series {
} else {
ewma_inf_hist_no_nulls(iter, options.alpha)
};
Ok(Float64Chunked::new_vec(self.name(), out).into_series())
let arr = prepare_primitive_array(out, options.min_periods);
Series::try_from((self.name(), Arc::new(arr) as ArrayRef))
}
}
}
Expand Down
5 changes: 4 additions & 1 deletion py-polars/polars/internals/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2095,6 +2095,7 @@ def ewm_mean(
half_life: Optional[float] = None,
alpha: Optional[float] = None,
adjust: bool = True,
min_periods: int = 1,
) -> "Expr":
r"""
Exponential moving average. Null values are replaced with 0.0.
Expand All @@ -2114,6 +2115,8 @@ def ewm_mean(
- When adjust = True the EW function is calculated using weights :math:`w_i = (1 - alpha)^i`
- When adjust = False the EW function is calculated recursively.
min_periods
Minimum number of observations in window required to have a value (otherwise result is Null).
"""
if com is not None and alpha is not None:
Expand All @@ -2129,7 +2132,7 @@ def ewm_mean(
raise ValueError(
"at least one of {com, span, halflife, alpha} should be set"
)
return wrap_expr(self._pyexpr.ewm_mean(alpha, adjust))
return wrap_expr(self._pyexpr.ewm_mean(alpha, adjust, min_periods))

# Below are the namespaces defined. Keep these at the end of the definition of Expr, as to not confuse mypy with
# the type annotation `str` with the namespace "str"
Expand Down
9 changes: 8 additions & 1 deletion py-polars/polars/internals/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -3101,6 +3101,7 @@ def ewm_mean(
half_life: Optional[float] = None,
alpha: Optional[float] = None,
adjust: bool = True,
min_periods: int = 1,
) -> "Series":
r"""
Exponential moving average. Null values are replaced with 0.0.
Expand All @@ -3120,11 +3121,17 @@ def ewm_mean(
- When adjust = True the EW function is calculated using weights :math:`w_i = (1 - alpha)^i`
- When adjust = False the EW function is calculated recursively.
min_periods
Minimum number of observations in window required to have a value (otherwise result is Null).
"""
return (
self.to_frame()
.select(pli.col(self.name).ewm_mean(com, span, half_life, alpha, adjust))
.select(
pli.col(self.name).ewm_mean(
com, span, half_life, alpha, adjust, min_periods
)
)
.to_series()
)

Expand Down
8 changes: 6 additions & 2 deletions py-polars/src/lazy/dsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1028,8 +1028,12 @@ impl PyExpr {
self.inner.clone().shuffle(seed).into()
}

pub fn ewm_mean(&self, alpha: f64, adjust: bool) -> Self {
let options = EWMOptions { alpha, adjust };
pub fn ewm_mean(&self, alpha: f64, adjust: bool, min_periods: usize) -> Self {
let options = EWMOptions {
alpha,
adjust,
min_periods,
};
self.inner.clone().ewm_mean(options).into()
}
}
Expand Down
32 changes: 20 additions & 12 deletions py-polars/tests/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1301,15 +1301,23 @@ def test_trigonometric(f: str) -> None:

def test_ewm() -> None:
a = pl.Series("a", [2, 5, 3])
assert a.ewm_mean(alpha=0.5, adjust=True).to_list() == [
2.0,
4.0,
3.4285714285714284,
]
assert a.ewm_mean(alpha=0.5, adjust=False).to_list() == [2.0, 3.5, 3.25]
assert pl.select(
pl.lit(a).ewm_mean(alpha=0.5, adjust=True)
).to_series().to_list() == [2.0, 4.0, 3.4285714285714284]
assert pl.select(
pl.lit(a).ewm_mean(alpha=0.5, adjust=False)
).to_series().to_list() == [2.0, 3.5, 3.25]
expected = pl.Series(
"a",
[
2.0,
4.0,
3.4285714285714284,
],
)
verify_series_and_expr_api(a, expected, "ewm_mean", alpha=0.5, adjust=True)
expected = pl.Series("a", [2.0, 3.5, 3.25])
verify_series_and_expr_api(a, expected, "ewm_mean", alpha=0.5, adjust=False)
a = pl.Series("a", [2, 3, 5, 7, 4])
expected = pl.Series("a", [None, 2.666667, 4.0, 5.6, 4.774194])
verify_series_and_expr_api(
a, expected, "ewm_mean", alpha=0.5, adjust=True, min_periods=2
)
expected = pl.Series("a", [None, None, 4.0, 5.6, 4.774194])
verify_series_and_expr_api(
a, expected, "ewm_mean", alpha=0.5, adjust=True, min_periods=3
)

0 comments on commit 5426765

Please sign in to comment.