Skip to content

Commit

Permalink
initial emwa implementation (#2150)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Dec 24, 2021
1 parent 112d159 commit 9e6fe01
Show file tree
Hide file tree
Showing 18 changed files with 324 additions and 1 deletion.
1 change: 1 addition & 0 deletions polars/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ true_div = ["polars-lazy/true_div"]
diagonal_concat = ["polars-core/diagonal_concat"]
abs = ["polars-core/abs", "polars-lazy/abs"]
dynamic_groupby = ["polars-core/dynamic_groupby", "polars-lazy/dynamic_groupby"]
ewma = ["polars-core/ewma", "polars-lazy/ewma"]

# don't use this
private = ["polars-lazy/private"]
Expand Down
88 changes: 88 additions & 0 deletions polars/polars-arrow/src/kernels/ew/average.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
use crate::trusted_len::{PushUnchecked, TrustedLen};
use num::Float;
use std::fmt::Debug;
use std::ops::AddAssign;
// See:
// https://github.com/pola-rs/polars/issues/2148
// https://stackoverflow.com/a/51392341/6717054

/// Exponentially weighted moving average using the "adjusted" formulation.
///
/// Output `k` is the weighted sum `sum_j (1 - alpha)^(k - j) * x_j` divided by
/// the running sum of the weights `sum_j (1 - alpha)^j`. The first output
/// equals the first input. The input must not contain nulls.
///
/// # Panics
///
/// Panics if the iterator reports no upper size bound, or if a position index
/// cannot be converted into `T`.
pub fn ewma_no_nulls<T, I>(vals: I, alpha: T) -> Vec<T>
where
    T: Float + AddAssign,
    I: IntoIterator<Item = T>,
    I::IntoIter: TrustedLen,
{
    let mut values = vals.into_iter();
    let capacity = values.size_hint().1.unwrap();
    if capacity == 0 {
        return vec![];
    }

    let seed = values.next().unwrap();
    let decay = T::one() - alpha;

    // `numerator` carries the decayed weighted sum of the inputs,
    // `denominator` the matching sum of weights (the first weight is 1).
    let mut numerator = seed;
    let mut denominator = T::one();

    let mut result = Vec::with_capacity(capacity);
    result.push(seed);

    // Positions start at 1 because position 0 was consumed as the seed.
    for (step, val) in (1usize..).zip(values) {
        denominator += decay.powf(T::from(step).unwrap());
        numerator = numerator * decay + val;
        // Safety:
        // `result` was allocated with the iterator's reported upper bound,
        // which the `TrustedLen` bound on the iterator is relied on to make exact.
        unsafe { result.push_unchecked(numerator / denominator) }
    }

    result
}

/// Exponentially weighted moving average using the plain recursive
/// ("infinite history", unadjusted) formulation.
///
/// The first output equals the first input; every later output is
/// `alpha * x_k + (1 - alpha) * y_{k-1}`, where `y_{k-1}` is the previous
/// output. The input must not contain nulls.
///
/// # Panics
///
/// Panics if the iterator reports no upper size bound.
pub fn ewma_inf_hist_no_nulls<T, I>(vals: I, alpha: T) -> Vec<T>
where
    T: Float + AddAssign + Debug,
    I: IntoIterator<Item = T>,
    I::IntoIter: TrustedLen,
{
    let mut iter = vals.into_iter();
    let len = iter.size_hint().1.unwrap();
    if len == 0 {
        return vec![];
    }

    let mut out = Vec::with_capacity(len);
    let first = iter.next().unwrap();
    out.push(first);
    let one_sub_alpha = T::one() - alpha;

    // Carry the previous output in a local instead of reading it back from
    // `out` through an unchecked index; the arithmetic is unchanged.
    let mut prev = first;
    for val in iter {
        prev = val * alpha + prev * one_sub_alpha;

        // Safety:
        // `out` was allocated with the iterator's reported upper bound, which
        // the `TrustedLen` bound on the iterator is relied on to make exact.
        unsafe { out.push_unchecked(prev) }
    }

    out
}

#[cfg(test)]
mod test {
    use super::*;

    /// Checks both kernels against hand-computed values for alpha = 0.5.
    #[test]
    fn test_ewma() {
        let input = [2.0, 5.0, 3.0];
        let alpha = 0.5;

        // Adjusted form: weighted sum divided by the running sum of weights.
        let adjusted = ewma_no_nulls(input.iter().copied(), alpha);
        assert_eq!(adjusted, [2.0, 4.0, 3.4285714285714284]);

        // Recursive form: y_k = alpha * x_k + (1 - alpha) * y_{k-1}.
        let recursive = ewma_inf_hist_no_nulls(input.iter().copied(), alpha);
        assert_eq!(recursive, [2.0, 3.5, 3.25]);
    }
}
42 changes: 42 additions & 0 deletions polars/polars-arrow/src/kernels/ew/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
mod average;

pub use average::*;

/// Options for exponentially weighted window kernels.
///
/// Build it from [`Default`] and refine it with the `and_*` methods; the
/// decay can be given as a span, a half-life or a center of mass, each of
/// which is converted into the smoothing factor `alpha`.
#[derive(Debug, Copy, Clone)]
pub struct ExponentialWindowOptions {
    /// Smoothing factor applied to the newest observation.
    pub alpha: f64,
    /// Selects the adjusted formulation (weighted sum divided by the sum of
    /// weights) when `true`, and the plain recursive formulation when `false`.
    pub adjust: bool,
}

impl Default for ExponentialWindowOptions {
    /// Defaults to `alpha = 0.5` with adjustment enabled.
    fn default() -> Self {
        Self {
            alpha: 0.5,
            adjust: true,
        }
    }
}

impl ExponentialWindowOptions {
    /// Sets whether the adjusted formulation is used.
    pub fn and_adjust(mut self, adjust: bool) -> Self {
        self.adjust = adjust;
        self
    }

    /// Sets the decay from a span: `alpha = 2 / (span + 1)`.
    ///
    /// # Panics
    ///
    /// Panics if `span < 1`.
    pub fn and_span(mut self, span: usize) -> Self {
        assert!(span >= 1);
        self.alpha = 2.0 / (span as f64 + 1.0);
        self
    }

    /// Sets the decay from a half-life: `alpha = 1 - exp(-ln(2) / halflife)`.
    ///
    /// # Panics
    ///
    /// Panics if `halflife <= 0`.
    pub fn and_halflife(mut self, halflife: f64) -> Self {
        assert!(halflife > 0.0);
        // The sign belongs outside the logarithm: `ln(-2)` is NaN, whereas the
        // formula needs `-ln(2)`.
        self.alpha = 1.0 - (-std::f64::consts::LN_2 / halflife).exp();
        self
    }

    /// Sets the decay from a center of mass: `alpha = 1 / (1 + com)`.
    ///
    /// # Panics
    ///
    /// Panics if `com < 0`.
    pub fn and_com(mut self, com: f64) -> Self {
        // `com == 0` is valid and yields `alpha == 1`.
        assert!(com >= 0.0);
        self.alpha = 1.0 / (1.0 + com);
        self
    }
}
1 change: 1 addition & 0 deletions polars/polars-arrow/src/kernels/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use arrow::array::BooleanArray;
use arrow::bitmap::utils::BitChunks;
use std::iter::Enumerate;
pub mod ew;
pub mod float;
pub mod list;
pub mod rolling;
Expand Down
1 change: 1 addition & 0 deletions polars/polars-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ diff = []
moment = []
diagonal_concat = []
abs = []
ewma = []

dynamic_groupby = ["polars-time", "dtype-datetime", "dtype-date"]

Expand Down
3 changes: 3 additions & 0 deletions polars/polars-core/src/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,6 @@ pub use crate::chunked_array::ops::rolling_window::RollingOptions;

#[cfg(feature = "dynamic_groupby")]
pub use polars_time::{groupby::ClosedWindow, Duration};

#[cfg(feature = "ewma")]
pub use polars_arrow::kernels::ew::ExponentialWindowOptions;
66 changes: 66 additions & 0 deletions polars/polars-core/src/series/ops/ewm.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
use crate::prelude::*;
pub use polars_arrow::kernels::ew::ExponentialWindowOptions;
use polars_arrow::kernels::ew::{ewma_inf_hist_no_nulls, ewma_no_nulls};

impl Series {
    /// Exponentially weighted moving average of this `Series`.
    ///
    /// Null values are replaced by zero (`FillNullStrategy::Zero`) before the
    /// computation. `Float32` and `Float64` inputs keep their dtype (for
    /// `Float32` the `alpha` option is narrowed to `f32`); any other dtype is
    /// cast to `Float64` first.
    ///
    /// `options.adjust` selects the kernel: `ewma_no_nulls` when `true`,
    /// `ewma_inf_hist_no_nulls` when `false`.
    ///
    /// # Errors
    ///
    /// Returns the error of the null fill or of the cast to `Float64` if
    /// either fails.
    pub fn ewm_mean(&self, options: ExponentialWindowOptions) -> Result<Self> {
        if self.null_count() > 0 {
            // Propagate a failing fill instead of panicking: this function
            // already returns a `Result`.
            return self.fill_null(FillNullStrategy::Zero)?.ewm_mean(options);
        }

        match self.dtype() {
            DataType::Float32 => {
                let ca = self.f32().unwrap();
                match self.n_chunks() {
                    // Single chunk: iterate the contiguous value buffer directly.
                    1 => {
                        let vals = ca.downcast_iter().next().unwrap();
                        let vals = vals.values().as_slice();
                        let out = if options.adjust {
                            ewma_no_nulls(vals.iter().copied(), options.alpha as f32)
                        } else {
                            ewma_inf_hist_no_nulls(vals.iter().copied(), options.alpha as f32)
                        };
                        Ok(Float32Chunked::new_vec(self.name(), out).into_series())
                    }
                    // Any other chunk count: iterate across chunks; nulls were
                    // ruled out above.
                    _ => {
                        let iter = ca.into_no_null_iter();
                        let out = if options.adjust {
                            ewma_no_nulls(iter, options.alpha as f32)
                        } else {
                            ewma_inf_hist_no_nulls(iter, options.alpha as f32)
                        };
                        Ok(Float32Chunked::new_vec(self.name(), out).into_series())
                    }
                }
            }
            DataType::Float64 => {
                let ca = self.f64().unwrap();
                match self.n_chunks() {
                    // Single chunk: iterate the contiguous value buffer directly.
                    1 => {
                        let vals = ca.downcast_iter().next().unwrap();
                        let vals = vals.values().as_slice();
                        let out = if options.adjust {
                            ewma_no_nulls(vals.iter().copied(), options.alpha)
                        } else {
                            ewma_inf_hist_no_nulls(vals.iter().copied(), options.alpha)
                        };
                        Ok(Float64Chunked::new_vec(self.name(), out).into_series())
                    }
                    // Any other chunk count: iterate across chunks; nulls were
                    // ruled out above.
                    _ => {
                        let iter = ca.into_no_null_iter();
                        let out = if options.adjust {
                            ewma_no_nulls(iter, options.alpha)
                        } else {
                            ewma_inf_hist_no_nulls(iter, options.alpha)
                        };
                        Ok(Float64Chunked::new_vec(self.name(), out).into_series())
                    }
                }
            }
            // Every other dtype is computed in `Float64`.
            _ => self.cast(&DataType::Float64)?.ewm_mean(options),
        }
    }
}
2 changes: 2 additions & 0 deletions polars/polars-core/src/series/ops/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#[cfg(feature = "diff")]
pub mod diff;
#[cfg(feature = "ewma")]
mod ewm;
#[cfg(feature = "moment")]
pub mod moment;
mod null;
Expand Down
1 change: 1 addition & 0 deletions polars/polars-lazy/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ list = ["polars-core/list"]
abs = ["polars-core/abs"]
random = ["polars-core/random"]
dynamic_groupby = ["polars-core/dynamic_groupby"]
ewma = ["polars-core/ewma"]

# no guarantees whatsoever
private = []
Expand Down
12 changes: 12 additions & 0 deletions polars/polars-lazy/src/dsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1822,6 +1822,18 @@ impl Expr {
pub fn shuffle(self, seed: u64) -> Self {
self.apply(move |s| Ok(s.shuffle(seed)), GetOutput::same_type())
}

/// Exponentially weighted moving average of the expression's output.
///
/// Applies `Series::ewm_mean` with the given options. Float inputs keep their
/// dtype; every other dtype is reported as `Float64`.
#[cfg(feature = "ewma")]
pub fn ewm_mean(self, options: ExponentialWindowOptions) -> Self {
    // Output dtype mapping: floats are preserved, everything else is Float64.
    let output_type = GetOutput::map_dtype(|dtype| match dtype {
        DataType::Float32 | DataType::Float64 => dtype.clone(),
        _ => DataType::Float64,
    });
    self.apply(move |series| series.ewm_mean(options), output_type)
}
}

/// Create a Column Expression based on a column name.
Expand Down
7 changes: 6 additions & 1 deletion polars/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,12 +113,15 @@
//! * gzip
//!
//! * `DataFrame` operations:
//! - `dynamic_groupby` - Groupby based on a time window instead of predefined keys.
//! - `pivot` - [pivot operation](crate::frame::groupby::GroupBy::pivot) on `DataFrame`s
//! - `sort_multiple` - Allow sorting a `DataFrame` on multiple columns
//! - `rows` - Create `DataFrame` from rows and extract rows from `DataFrames`.
//! - `asof_join` - Join as of, to join on nearest keys instead of exact equality match.
//! - `cross_join` - Create the cartesian product of two DataFrames.
//! - `groupby_list` - Allow groupby operation on keys of type List.
//! - `row_hash` - Utility to hash DataFrame rows to UInt64Chunked
//! - `diagonal_concat` - Concat diagonally thereby combining different schemas.
//! * `Series` operations:
//! - `is_in` - [Check for membership in `Series`](crate::chunked_array::ops::IsIn)
//! - `zip_with` - [Zip two Series/ ChunkedArrays](crate::chunked_array::ops::ChunkZip)
Expand All @@ -139,10 +142,12 @@
//! - `list` - [List utils](crate::chunked_array::list::namespace)
//! - `rank` - Ranking algorithms.
//! - `moment` - kurtosis and skew statistics
//! - `ewma` - Exponential moving average windows
//! - `abs` - Get absolute values of Series
//! - `arange` - Range operation on Series
//! * `DataFrame` pretty printing (Choose one or none, but not both):
//! - `plain_fmt` - no overflowing (less compilation times)
//! - `pretty_fmt` - cell overflow (increased compilation times)
//! - `row_hash` - Utility to hash DataFrame rows to UInt64Chunked
//!
//! ## Compile times and opt-in data types
//! As mentioned above, Polars `Series` are wrappers around
Expand Down
1 change: 1 addition & 0 deletions py-polars/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ features = [
"dtype-categorical",
"diagonal_concat",
"abs",
"ewma",
]

# [patch.crates-io]
Expand Down
1 change: 1 addition & 0 deletions py-polars/docs/source/reference/expression.rst
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ Computations
Expr.rolling_median
Expr.rolling_quantile
Expr.rolling_skew
Expr.ewm_mean
Expr.hash
Expr.abs
Expr.rank
Expand Down
1 change: 1 addition & 0 deletions py-polars/docs/source/reference/series.rst
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ Computations
Series.rolling_median
Series.rolling_quantile
Series.rolling_skew
Series.ewm_mean
Series.hash
Series.peak_max
Series.peak_min
Expand Down
43 changes: 43 additions & 0 deletions py-polars/polars/internals/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2088,6 +2088,49 @@ def shuffle(self, seed: int = 0) -> "Expr":
"""
return wrap_expr(self._pyexpr.shuffle(seed))

def ewm_mean(
    self,
    com: Optional[float] = None,
    span: Optional[float] = None,
    halflife: Optional[float] = None,
    alpha: Optional[float] = None,
    adjust: bool = True,
) -> "Expr":
    r"""
    Exponential moving average. Null values are replaced with 0.0.

    Exactly one of ``com``, ``span``, ``halflife`` or ``alpha`` is expected.
    If several are given, the last one set among ``com``, ``span`` and
    ``halflife`` (in that order) determines the smoothing factor and
    overrides an explicit ``alpha``.

    Parameters
    ----------
    com
        Specify decay in terms of center of mass, :math:`alpha = 1/(1 + com) \;for\; com >= 0`.
    span
        Specify decay in terms of span, :math:`alpha = 2/(span + 1) \;for\; span >= 1`
    halflife
        Specify decay in terms of half-life, :math:`alpha = 1 - exp(-ln(2) / halflife) \;for\; halflife > 0`
    alpha
        Specify smoothing factor alpha directly, :math:`0 < alpha < 1`.
    adjust
        Divide by decaying adjustment factor in beginning periods to account for imbalance in relative weightings

            - When adjust = True the EW function is calculated using weights :math:`w_i = (1 - alpha)^i`
            - When adjust = False the EW function is calculated recursively.

    Raises
    ------
    ValueError
        If none of ``com``, ``span``, ``halflife`` or ``alpha`` is given.
    """
    # Each decay parameter is converted on its own. Requiring `alpha` to be
    # set as well made every call without an explicit `alpha` fail below.
    if com is not None:
        assert com >= 0.0
        alpha = 1.0 / (1.0 + com)
    if span is not None:
        assert span >= 1.0
        alpha = 2.0 / (span + 1.0)
    if halflife is not None:
        assert halflife > 0.0
        alpha = 1.0 - np.exp(-np.log(2.0) / halflife)
    if alpha is None:
        raise ValueError(
            "at least one of {com, span, halflife, alpha} should be set"
        )
    return wrap_expr(self._pyexpr.ewm_mean(alpha, adjust))

# Below are the namespaces defined. Keep these at the end of the definition of Expr, as to not confuse mypy with
# the type annotation `str` with the namespace "str"

Expand Down

0 comments on commit 9e6fe01

Please sign in to comment.