Skip to content

Commit

Permalink
ewma null implementation (#2166)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Dec 25, 2021
1 parent 5426765 commit 174f7eb
Show file tree
Hide file tree
Showing 3 changed files with 248 additions and 26 deletions.
198 changes: 189 additions & 9 deletions polars/polars-arrow/src/kernels/ewm/average.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use std::ops::AddAssign;
// https://github.com/pola-rs/polars/issues/2148
// https://stackoverflow.com/a/51392341/6717054

// this is the adjusted variant in pandas
pub fn ewma_no_nulls<T, I>(vals: I, alpha: T) -> Vec<T>
where
T: Float + AddAssign,
Expand All @@ -25,18 +26,58 @@ where
let mut ewma_old = first;
let one_sub_alpha = T::one() - alpha;

for (i, val) in iter.enumerate() {
let i = i + 1;
weight += one_sub_alpha.powf(T::from(i).unwrap());
let mut i = T::from(out.len()).unwrap();
for val in iter {
weight += one_sub_alpha.powf(i);
ewma_old = ewma_old * (one_sub_alpha) + val;
// Safety:
// we allocated vals.len()
unsafe { out.push_unchecked(ewma_old / weight) }
unsafe { out.push_unchecked(ewma_old / weight) };
i += T::one();
}

out
}

pub fn ewma<T, I>(vals: I, alpha: T) -> (usize, Vec<T>)
where
T: Float + AddAssign,
I: IntoIterator<Item = Option<T>>,
I::IntoIter: TrustedLen,
{
let mut iter = vals.into_iter();
let len = iter.size_hint().1.unwrap();
if len == 0 {
return (0, vec![]);
}
let mut weight = T::one();
let mut out = Vec::with_capacity(len);

let leading_null_count = set_first_none_null(&mut iter, &mut out);
let mut ewma_old = out[out.len() - 1];
let one_sub_alpha = T::one() - alpha;

let mut i = T::one();
let mut prev = out[out.len() - 1];
for opt_val in iter {
prev = match opt_val {
Some(val) => {
weight += one_sub_alpha.powf(i);
ewma_old = ewma_old * (one_sub_alpha) + val;
i += T::one();
ewma_old / weight
}
None => prev,
};
// Safety:
// we allocated vals.len()
unsafe { out.push_unchecked(prev) };
}

(leading_null_count, out)
}

// this is the non-adjusted variant in pandas
pub fn ewma_inf_hist_no_nulls<T, I>(vals: I, alpha: T) -> Vec<T>
where
T: Float + AddAssign + Debug,
Expand All @@ -54,19 +95,88 @@ where
out.push(first);
let one_sub_alpha = T::one() - alpha;

for (i, val) in iter.enumerate() {
let i = i + 1;
let mut prev = out[0];
for val in iter {
let output_val = val * alpha + prev * one_sub_alpha;
prev = output_val;

// Safety:
// we add first, so i - 1 always exits
let output_val = val * alpha + unsafe { *out.get_unchecked(i - 1) } * one_sub_alpha;
// we allocated vals.len()
unsafe { out.push_unchecked(output_val) }
}

out
}

// this is the non-adjusted variant in pandas
/// # Arguments
///
/// * `vals` - Iterator of optional values
/// * `alpha` - Smoothing factor
///
/// Returns the a tuple with:
/// * `leading_null_count` - the amount of nulls that must be set by the caller
/// * `smoothed values` - The result of the ewma
///
pub fn ewma_inf_hists<T, I>(vals: I, alpha: T) -> (usize, Vec<T>)
where
T: Float + AddAssign + Debug,
I: IntoIterator<Item = Option<T>>,
I::IntoIter: TrustedLen,
{
let mut iter = vals.into_iter();
let len = iter.size_hint().1.unwrap();
if len == 0 {
return (0, vec![]);
}

let mut out = Vec::with_capacity(len);

let leading_null_count = set_first_none_null(&mut iter, &mut out);
let one_sub_alpha = T::one() - alpha;
let mut prev = out[out.len() - 1];

for opt_val in iter {
let output_val = match opt_val {
Some(val) => {
// Safety:
// we add first, so i - 1 always exits
let output = val * alpha + prev * one_sub_alpha;
prev = output;
prev
}
None => prev,
};

// Safety:
// we allocated vals.len()
unsafe { out.push_unchecked(output_val) }
}

out
(leading_null_count, out)
}

pub fn set_first_none_null<T, I>(iter: &mut I, out: &mut Vec<T>) -> usize
where
T: Float + AddAssign,
I: Iterator<Item = Option<T>>,
{
let mut leading_null_count = 0;
// find first non null
for opt_val in iter {
match opt_val {
// these will be later masked out by the validity
None => {
leading_null_count += 1;
unsafe { out.push_unchecked(T::zero()) };
}
Some(val) => {
unsafe { out.push_unchecked(val) };
break;
}
}
}
leading_null_count
}

#[cfg(test)]
Expand All @@ -85,4 +195,74 @@ mod test {
let expected = [2.0, 3.5, 3.25];
assert_eq!(out, expected);
}

#[test]
fn test_ewma_null() {
let vals = &[
Some(2.0),
Some(3.0),
Some(5.0),
Some(7.0),
None,
None,
None,
Some(4.0),
];
let (cnt, out) = ewma_inf_hists(vals.into_iter().copied(), 0.5);
assert_eq!(cnt, 0);
let expected = [2.0, 2.5, 3.75, 5.375, 5.375, 5.375, 5.375, 4.6875];
assert_eq!(out, expected);
let vals = &[
None,
None,
Some(5.0),
Some(7.0),
None,
Some(2.0),
Some(1.0),
Some(4.0),
];
let (cnt, out) = ewma_inf_hists(vals.into_iter().copied(), 0.5);
let expected = [0.0, 0.0, 5.0, 6.0, 6.0, 4.0, 2.5, 3.25];
assert_eq!(cnt, 2);
assert_eq!(out, expected);

let (cnt, out) = ewma(vals.into_iter().copied(), 0.5);
let expected = [
0.0,
0.0,
5.0,
6.333333333333333,
6.333333333333333,
3.857142857142857,
2.3333333333333335,
3.193548387096774,
];
assert_eq!(cnt, 2);
assert_eq!(out, expected);

let vals = &[
None,
Some(1.0),
Some(5.0),
Some(7.0),
None,
Some(2.0),
Some(1.0),
Some(4.0),
];
let (cnt, out) = ewma(vals.into_iter().copied(), 0.5);
let expected = [
0.0,
1.0,
3.6666666666666665,
5.571428571428571,
5.571428571428571,
3.6666666666666665,
2.2903225806451615,
3.1587301587301586,
];
assert_eq!(cnt, 1);
assert_eq!(out, expected);
}
}
54 changes: 37 additions & 17 deletions polars/polars-core/src/series/ops/ewm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,17 @@ use crate::prelude::*;
use arrow::bitmap::MutableBitmap;
use arrow::types::NativeType;
pub use polars_arrow::kernels::ewm::EWMOptions;
use polars_arrow::kernels::ewm::{ewma_inf_hist_no_nulls, ewma_no_nulls};
use polars_arrow::kernels::ewm::{ewma, ewma_inf_hist_no_nulls, ewma_inf_hists, ewma_no_nulls};
use polars_arrow::prelude::FromData;
use std::convert::TryFrom;

fn prepare_primitive_array<T: NativeType>(vals: Vec<T>, min_periods: usize) -> PrimitiveArray<T> {
if min_periods > 1 {
fn prepare_primitive_array<T: NativeType>(
vals: Vec<T>,
min_periods: usize,
leading_nulls: usize,
) -> PrimitiveArray<T> {
let leading = std::cmp::max(min_periods, leading_nulls);
if leading > 1 {
let mut validity = MutableBitmap::with_capacity(vals.len());
validity.extend_constant(min_periods, false);
validity.extend_constant(vals.len() - min_periods, true);
Expand All @@ -20,15 +25,8 @@ fn prepare_primitive_array<T: NativeType>(vals: Vec<T>, min_periods: usize) -> P

impl Series {
pub fn ewm_mean(&self, options: EWMOptions) -> Result<Self> {
if self.null_count() > 0 {
return self
.fill_null(FillNullStrategy::Zero)
.unwrap()
.ewm_mean(options);
}

match self.dtype() {
DataType::Float32 => {
match (self.dtype(), self.null_count()) {
(DataType::Float32, 0) => {
let ca = self.f32().unwrap();
match self.n_chunks() {
1 => {
Expand All @@ -39,7 +37,7 @@ impl Series {
} else {
ewma_inf_hist_no_nulls(vals.iter().copied(), options.alpha as f32)
};
let arr = prepare_primitive_array(out, options.min_periods);
let arr = prepare_primitive_array(out, options.min_periods, 0);
Series::try_from((self.name(), Arc::new(arr) as ArrayRef))
}
_ => {
Expand All @@ -49,12 +47,12 @@ impl Series {
} else {
ewma_inf_hist_no_nulls(iter, options.alpha as f32)
};
let arr = prepare_primitive_array(out, options.min_periods);
let arr = prepare_primitive_array(out, options.min_periods, 0);
Series::try_from((self.name(), Arc::new(arr) as ArrayRef))
}
}
}
DataType::Float64 => {
(DataType::Float64, 0) => {
let ca = self.f64().unwrap();
match self.n_chunks() {
1 => {
Expand All @@ -65,7 +63,7 @@ impl Series {
} else {
ewma_inf_hist_no_nulls(vals.iter().copied(), options.alpha)
};
let arr = prepare_primitive_array(out, options.min_periods);
let arr = prepare_primitive_array(out, options.min_periods, 0);
Series::try_from((self.name(), Arc::new(arr) as ArrayRef))
}
_ => {
Expand All @@ -75,11 +73,33 @@ impl Series {
} else {
ewma_inf_hist_no_nulls(iter, options.alpha)
};
let arr = prepare_primitive_array(out, options.min_periods);
let arr = prepare_primitive_array(out, options.min_periods, 0);
Series::try_from((self.name(), Arc::new(arr) as ArrayRef))
}
}
}
(DataType::Float32, _) => {
let ca = self.f32().unwrap();
let iter = ca.into_iter();
let (leading_nulls, out) = if options.adjust {
ewma(iter, options.alpha as f32)
} else {
ewma_inf_hists(iter, options.alpha as f32)
};
let arr = prepare_primitive_array(out, options.min_periods, leading_nulls);
Series::try_from((self.name(), Arc::new(arr) as ArrayRef))
}
(DataType::Float64, _) => {
let ca = self.f64().unwrap();
let iter = ca.into_iter();
let (leading_nulls, out) = if options.adjust {
ewma(iter, options.alpha as f64)
} else {
ewma_inf_hists(iter, options.alpha)
};
let arr = prepare_primitive_array(out, options.min_periods, leading_nulls);
Series::try_from((self.name(), Arc::new(arr) as ArrayRef))
}
_ => self.cast(&DataType::Float64)?.ewm_mean(options),
}
}
Expand Down
22 changes: 22 additions & 0 deletions py-polars/tests/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1321,3 +1321,25 @@ def test_ewm() -> None:
verify_series_and_expr_api(
a, expected, "ewm_mean", alpha=0.5, adjust=True, min_periods=3
)

a = pl.Series("a", [None, 1.0, 5.0, 7.0, None, 2.0, 5.0, 4])
expected = pl.Series(
"a",
[
None,
1.0,
3.6666666666666665,
5.571428571428571,
5.571428571428571,
3.6666666666666665,
4.354838709677419,
4.174603174603175,
],
)
verify_series_and_expr_api(
a, expected, "ewm_mean", alpha=0.5, adjust=True, min_periods=1
)
expected = pl.Series("a", [None, 1.0, 3.0, 5.0, 5.0, 3.5, 4.25, 4.125])
verify_series_and_expr_api(
a, expected, "ewm_mean", alpha=0.5, adjust=False, min_periods=1
)

0 comments on commit 174f7eb

Please sign in to comment.