Skip to content

Commit

Permalink
Implement pow/rpow for Series (#3908)
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego committed Jul 6, 2022
1 parent 8b8630b commit c46f78b
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 11 deletions.
6 changes: 3 additions & 3 deletions py-polars/polars/internals/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,10 +141,10 @@ def __mod__(self, other: Any) -> Expr:
def __rmod__(self, other: Any) -> Expr:
return wrap_expr(self.__to_pyexpr(other) % self._pyexpr)

def __pow__(self, power: float | Expr | int, modulo: None = None) -> Expr:
def __pow__(self, power: int | float | pli.Series | Expr) -> Expr:
return self.pow(power)

def __rpow__(self, base: float | Expr | int) -> Expr:
def __rpow__(self, base: int | float | Expr) -> Expr:
return pli.expr_to_lit_or_expr(base) ** self

def __ge__(self, other: Any) -> Expr:
Expand Down Expand Up @@ -2076,7 +2076,7 @@ def tail(self, n: int | None = None) -> Expr:
"""
return wrap_expr(self._pyexpr.tail(n))

def pow(self, exponent: float | Expr) -> Expr:
def pow(self, exponent: int | float | pli.Series | Expr) -> Expr:
"""
Raise expression to the power of exponent.
Examples
Expand Down
10 changes: 3 additions & 7 deletions py-polars/polars/internals/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,23 +434,19 @@ def __rmul__(self, other: Any) -> Series:
raise ValueError("first cast to integer before multiplying datelike dtypes")
return self._arithmetic(other, "mul", "mul_<>")

def __pow__(self, power: float, modulo: None = None) -> Series:
def __pow__(self, power: int | float | Series) -> Series:
if self.is_datelike():
raise ValueError(
"first cast to integer before raising datelike dtypes to a power"
)
if not _NUMPY_AVAILABLE:
raise ImportError("'numpy' is required for this functionality.")
return np.power(self, power) # type: ignore
return self.to_frame().select(pli.col(self.name).pow(power)).to_series()

def __rpow__(self, other: Any) -> Series:
if self.is_datelike():
raise ValueError(
"first cast to integer before raising datelike dtypes to a power"
)
if not _NUMPY_AVAILABLE:
raise ImportError("'numpy' is required for this functionality.")
return np.power(other, self) # type: ignore
return self.to_frame().select(other ** pli.col(self.name)).to_series()

def __neg__(self) -> Series:
return 0 - self
Expand Down
27 changes: 26 additions & 1 deletion py-polars/tests/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pytest

import polars as pl
from polars.datatypes import Float64, Int32, Int64, UInt32, UInt64
from polars.datatypes import Date, Float64, Int32, Int64, UInt32, UInt64
from polars.testing import assert_series_equal, verify_series_and_expr_api


Expand Down Expand Up @@ -193,6 +193,31 @@ def test_arithmetic(s: pl.Series) -> None:
2**a


def test_power() -> None:
a = pl.Series([1, 2], dtype=Int64)
b = pl.Series([None, 2.0], dtype=Float64)
c = pl.Series([date(2020, 2, 28), date(2020, 3, 1)], dtype=Date)

# pow
assert_series_equal(a**2, pl.Series([1.0, 4.0], dtype=Float64))
assert_series_equal(b**3, pl.Series([None, 8.0], dtype=Float64))
assert_series_equal(a**a, pl.Series([1.0, 4.0], dtype=Float64))
assert_series_equal(b**b, pl.Series([None, 4.0], dtype=Float64))
assert_series_equal(a**b, pl.Series([None, 4.0], dtype=Float64))
with pytest.raises(ValueError):
c**2
with pytest.raises(pl.ComputeError):
a ** "hi" # type: ignore[operator]

# rpow
assert_series_equal(2.0**a, pl.Series("literal", [2.0, 4.0], dtype=Float64))
assert_series_equal(2**b, pl.Series("literal", [None, 4.0], dtype=Float64))
with pytest.raises(ValueError):
2**c
with pytest.raises(pl.ComputeError):
"hi" ** a


def test_add_string() -> None:
s = pl.Series(["hello", "weird"])
result = s + " world"
Expand Down

0 comments on commit c46f78b

Please sign in to comment.