Skip to content

Commit

Permalink
python: __rpow__ and value error on datelike arithmetic
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Jan 7, 2022
1 parent 814eac5 commit bf6e189
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 1 deletion.
35 changes: 34 additions & 1 deletion py-polars/polars/internals/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
from polars.datatypes import List as PlList
from polars.datatypes import (
Object,
Time,
UInt8,
UInt16,
UInt32,
Expand Down Expand Up @@ -356,25 +357,40 @@ def __sub__(self, other: Any) -> "Series":
return self._arithmetic(other, "sub", "sub_<>")

def __truediv__(self, other: Any) -> "Series":
if self.is_datelike():
raise ValueError("first cast to integer before dividing datelike dtypes")

# this branch is exactly the floordiv function without rounding the floats
if self.is_float():
return self._arithmetic(other, "div", "div_<>")

return self.cast(Float64) / other

def __floordiv__(self, other: Any) -> "Series":
if self.is_datelike():
raise ValueError("first cast to integer before dividing datelike dtypes")
result = self._arithmetic(other, "div", "div_<>")
if self.is_float():
result = result.floor()
return result

def __mul__(self, other: Any) -> "Series":
if self.is_datelike():
raise ValueError("first cast to integer before multiplying datelike dtypes")
return self._arithmetic(other, "mul", "mul_<>")

def __mod__(self, other: Any) -> "Series":
if self.is_datelike():
raise ValueError(
"first cast to integer before applying modulo on datelike dtypes"
)
return self._arithmetic(other, "rem", "rem_<>")

def __rmod__(self, other: Any) -> "Series":
if self.is_datelike():
raise ValueError(
"first cast to integer before applying modulo on datelike dtypes"
)
return self._arithmetic(other, "rem", "rem_<>_rhs")

def __radd__(self, other: Any) -> "Series":
Expand All @@ -389,6 +405,8 @@ def __invert__(self) -> "Series":
return NotImplemented

def __rtruediv__(self, other: Any) -> np.ndarray:
if self.is_datelike():
raise ValueError("first cast to integer before dividing datelike dtypes")
if self.is_float():
self.__rfloordiv__(other)

Expand All @@ -398,14 +416,29 @@ def __rtruediv__(self, other: Any) -> np.ndarray:
return self.cast(Float64).__rfloordiv__(other) # type: ignore

def __rfloordiv__(self, other: Any) -> "Series":
if self.is_datelike():
raise ValueError("first cast to integer before dividing datelike dtypes")
return self._arithmetic(other, "div", "div_<>_rhs")

def __rmul__(self, other: Any) -> "Series":
if self.is_datelike():
raise ValueError("first cast to integer before multiplying datelike dtypes")
return self._arithmetic(other, "mul", "mul_<>")

def __pow__(self, power: float, modulo: None = None) -> "Series":
if self.is_datelike():
raise ValueError(
"first cast to integer before raising datelike dtypes to a power"
)
return np.power(self, power) # type: ignore

def __rpow__(self, other: Any) -> "Series":
if self.is_datelike():
raise ValueError(
"first cast to integer before raising datelike dtypes to a power"
)
return np.power(other, self) # type: ignore

def __neg__(self) -> "Series":
return 0 - self

Expand Down Expand Up @@ -1847,7 +1880,7 @@ def is_datelike(self) -> bool:
True
"""
return self.dtype in (Date, Datetime, Duration)
return self.dtype in (Date, Datetime, Duration, Time)

def is_float(self) -> bool:
"""
Expand Down
22 changes: 22 additions & 0 deletions py-polars/tests/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,28 @@ def test_arithmetic(s: pl.Series) -> None:
assert ((1.0 + a) == [2, 3]).sum() == 2
assert ((1.0 % a) == [0, 1]).sum() == 2

a = pl.Series("a", [datetime(2021, 1, 1)])
with pytest.raises(ValueError):
a // 2
with pytest.raises(ValueError):
a / 2
with pytest.raises(ValueError):
a * 2
with pytest.raises(ValueError):
a % 2
with pytest.raises(ValueError):
a ** 2
with pytest.raises(ValueError):
2 / a
with pytest.raises(ValueError):
2 // a
with pytest.raises(ValueError):
2 * a
with pytest.raises(ValueError):
2 % a
with pytest.raises(ValueError):
2 ** a


def test_add_string() -> None:
s = pl.Series(["hello", "weird"])
Expand Down

0 comments on commit bf6e189

Please sign in to comment.