Skip to content

Commit

Permalink
Fix Python EWM alpha (#2212)
Browse files Browse the repository at this point in the history
  • Loading branch information
mhconradt committed Dec 29, 2021
1 parent d6eeac2 commit 99a6893
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 8 deletions.
15 changes: 7 additions & 8 deletions py-polars/polars/internals/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2133,7 +2133,7 @@ def ewm_mean(
Minimum number of observations in window required to have a value (otherwise result is Null).
"""
_prepare_alpha(com, span, half_life, alpha)
alpha = _prepare_alpha(com, span, half_life, alpha)
return wrap_expr(self._pyexpr.ewm_mean(alpha, adjust, min_periods))

def ewm_std(
Expand Down Expand Up @@ -2167,7 +2167,7 @@ def ewm_std(
Minimum number of observations in window required to have a value (otherwise result is Null).
"""
_prepare_alpha(com, span, half_life, alpha)
alpha = _prepare_alpha(com, span, half_life, alpha)
return wrap_expr(self._pyexpr.ewm_std(alpha, adjust, min_periods))

def ewm_var(
Expand Down Expand Up @@ -2201,7 +2201,7 @@ def ewm_var(
Minimum number of observations in window required to have a value (otherwise result is Null).
"""
_prepare_alpha(com, span, half_life, alpha)
alpha = _prepare_alpha(com, span, half_life, alpha)
return wrap_expr(self._pyexpr.ewm_var(alpha, adjust, min_periods))

def extend(self, value: Optional[Union[int, float, str, bool]], n: int) -> "Expr":
Expand Down Expand Up @@ -2908,16 +2908,15 @@ def _prepare_alpha(
half_life: Optional[float] = None,
alpha: Optional[float] = None,
) -> float:

if com is not None and alpha is not None:
if com is not None and alpha is None:
assert com >= 0.0
alpha = 1.0 / (1.0 + com)
if span is not None and alpha is not None:
if span is not None and alpha is None:
assert span >= 1.0
alpha = 2.0 / (span + 1.0)
if half_life is not None and alpha is not None:
if half_life is not None and alpha is None:
assert half_life > 0.0
alpha = 1.0 - np.exp(-np.log(2.0) / half_life)
if alpha is None:
raise ValueError("at least one of {com, span, halflife, alpha} should be set")
raise ValueError("at least one of {com, span, half_life, alpha} should be set")
return alpha
2 changes: 2 additions & 0 deletions py-polars/tests/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1275,6 +1275,8 @@ def test_ewm_mean() -> None:
],
)
verify_series_and_expr_api(a, expected, "ewm_mean", alpha=0.5, adjust=True)
expected = pl.Series("a", [2.0, 3.8, 3.421053])
verify_series_and_expr_api(a, expected, "ewm_mean", com=2.0, adjust=True)
expected = pl.Series("a", [2.0, 3.5, 3.25])
verify_series_and_expr_api(a, expected, "ewm_mean", alpha=0.5, adjust=False)
a = pl.Series("a", [2, 3, 5, 7, 4])
Expand Down

0 comments on commit 99a6893

Please sign in to comment.