Skip to content

Commit

Permalink
feat[rust, python]: introduce bias parameter to ewm_var and `ewm_…
Browse files Browse the repository at this point in the history
…mean` (#4636)
  • Loading branch information
matteosantama committed Sep 9, 2022
1 parent f1726fd commit faf7e60
Show file tree
Hide file tree
Showing 9 changed files with 325 additions and 178 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
*.iml
*.so
*.ipynb
.DS_Store
.ENV
.coverage
Expand Down
57 changes: 27 additions & 30 deletions polars/polars-arrow/src/kernels/ewm/average.rs
Original file line number Diff line number Diff line change
@@ -1,47 +1,44 @@
use std::ops::AddAssign;

use arrow::array::PrimitiveArray;
use arrow::types::NativeType;
use num::Float;

use crate::utils::CustomIterTools;

// See:
// https://github.com/pola-rs/polars/issues/2148
// https://stackoverflow.com/a/51392341/6717054

/// Exponentially weighted moving average of `xs`, producing one output slot per input slot.
///
/// NOTE(review): this listing is a rendered diff (hunk `@@ -1,47 +1,44 @@`) shown without
/// `+`/`-` markers, so lines from the version before the commit and the version after it
/// appear interleaved. The `denom` / `opt_ewma` / `value` lines look like the removed
/// implementation and the `wgt` / `wgt_sum` / `opt_mean` lines like the added one -- confirm
/// against the real file before relying on any single line.
///
/// * `xs` - input values; a null slot leaves the running state untouched.
/// * `alpha` - weight given to the newest observation (`1 - alpha` goes to the history).
///   It is not validated here; presumably callers guarantee `0 < alpha <= 1` -- TODO confirm.
/// * `adjust` - when `true` the weight normaliser starts at zero and accumulates with each
///   observation; when `false` it starts at one, which reduces the update to
///   `x * alpha + prev * (1 - alpha)`.
/// * `min_periods` - number of non-null observations required before a slot yields a value;
///   earlier slots are `None`.
///
/// The first non-null value seeds the average (`unwrap_or(x)`), and null input slots repeat
/// the most recent average once `min_periods` has been reached.
pub fn ewm_mean<T>(
xs: &PrimitiveArray<T>,
alpha: T,
adjust: bool,
min_periods: usize,
) -> PrimitiveArray<T>
where
// Two bound lines are shown; the second adds `AddAssign` (presumably the post-commit bound).
T: Float + NativeType,
T: Float + NativeType + AddAssign,
{
// Normaliser of the `denom`-based recursion (used only in the `value` computation below).
let mut denom = T::zero();
let one_sub_alpha = T::one() - alpha;

// Running average of the `wgt_sum`-based recursion; `None` until the first non-null value.
let mut opt_mean = None;
// Count of non-null inputs seen so far, compared against `min_periods` for every slot.
let mut non_null_cnt = 0usize;

// Running average of the `denom`-based recursion; `None` until the first non-null value.
let mut opt_ewma = None;
let wgt = alpha;
// With `adjust == false` this starts at 1 and stays 1 after every update
// ((1 - alpha) * 1 + alpha), so `wgt / wgt_sum` is simply `alpha`.
let mut wgt_sum = if adjust { T::zero() } else { T::one() };

xs.iter()
.map(|opt_x| {
if let Some(&x) = opt_x {
non_null_cnt += 1;

// Seed with the current value when there is no history yet.
let prev_ewma = opt_ewma.unwrap_or(x);
let prev_mean = opt_mean.unwrap_or(x);

// Decay the accumulated weight and add the weight of the new observation.
wgt_sum = one_sub_alpha * wgt_sum + wgt;

// Incremental form: move the previous mean towards `x` by the normalised weight.
let curr_mean = prev_mean + (x - prev_mean) * wgt / wgt_sum;

// `denom`-based form of the same update: the adjusted branch rescales the previous
// average by the old normaliser before dividing by the new one.
let value = if adjust {
let numer = prev_ewma * denom * one_sub_alpha + x;
denom = T::one() + one_sub_alpha * denom;
numer / denom
} else {
x * alpha + prev_ewma * one_sub_alpha
};
opt_ewma = Some(value);
opt_mean = Some(curr_mean);
}
// Evaluated for null slots too: they emit the last stored average (or `None`).
match non_null_cnt < min_periods {
true => None,
// Two `false` arms are an artifact of the diff rendering: one returns the
// `denom`-based state, the other the `wgt_sum`-based state.
false => opt_ewma,
false => opt_mean,
}
})
.collect_trusted()
Expand All @@ -62,8 +59,8 @@ mod test {
false => PrimitiveArray::from([Some(1.0f32), Some(1.5f32), Some(2.25f32)]),
true => PrimitiveArray::from([
Some(1.0f32),
Some(1.6666666666666667f32),
Some(2.4285714285714284f32),
Some(1.6666667f32), // <-- pandas: 1.66666667
Some(2.42857143),
]),
};
assert_eq!(result, expected);
Expand Down Expand Up @@ -132,11 +129,11 @@ mod test {
None,
None,
Some(5.0f32),
Some(6.333333333333333f32),
Some(6.333333333333333f32),
Some(3.857142857142857f32),
Some(2.3333333333333335f32),
Some(3.193548387096774f32),
Some(6.33333333f32),
Some(6.33333333f32),
Some(3.85714286f32),
Some(2.3333335f32), // <-- pandas: 2.33333333
Some(3.19354839f32),
]);
assert_eq!(adjusted_result, adjusted_expected);

Expand All @@ -154,12 +151,12 @@ mod test {
let expected = PrimitiveArray::from([
None,
Some(1.0f32),
Some(3.6666666666666665f32),
Some(5.571428571428571f32),
Some(5.571428571428571f32),
Some(3.6666666666666665f32),
Some(2.2903225806451615f32),
Some(3.1587301587301586f32),
Some(3.66666667f32),
Some(5.57142857f32),
Some(5.57142857f32),
Some(3.66666667),
Some(2.2903228f32), // <-- pandas: 2.29032258
Some(3.15873016f32),
]);
assert_eq!(result, expected);
}
Expand Down
2 changes: 2 additions & 0 deletions polars/polars-arrow/src/kernels/ewm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ pub use variance::*;
/// Settings for the exponentially weighted moving (EWM) kernels.
///
/// The `Default` implementation shown further down uses `alpha = 0.5`, `adjust = true`,
/// `bias = false` and `min_periods = 1`.
pub struct EWMOptions {
/// Smoothing factor: the weight given to the newest observation.
pub alpha: f64,
/// Whether the weights are normalised by their running sum (`true`) or the plain
/// `x * alpha + prev * (1 - alpha)` recursion is used (`false`).
pub adjust: bool,
/// Added by this commit for the variance kernel (per the commit title).
/// NOTE(review): presumably selects the biased rather than the bias-corrected estimate --
/// confirm against `variance.rs`, which is not visible here.
pub bias: bool,
/// Number of non-null observations required before a result is produced instead of null.
pub min_periods: usize,
}

Expand All @@ -17,6 +18,7 @@ impl Default for EWMOptions {
Self {
alpha: 0.5,
adjust: true,
bias: false,
min_periods: 1,
}
}
Expand Down

0 comments on commit faf7e60

Please sign in to comment.