Skip to content

Commit

Permalink
Refactor Series arithmetic methods (#2121)
Browse files Browse the repository at this point in the history
* Refactor Series arithmetic methods

This follows the same approach already used for the comparison methods (`_comp`).

Note that I have removed the `match_dtype` function calls; `maybe_cast` already does the same thing, with the exception that strings and booleans are no longer cast to integers. I think that is the desired behaviour, as no test is failing.

* Revert `f is None` checks
  • Loading branch information
zundertj committed Dec 22, 2021
1 parent 29350bb commit f952801
Showing 1 changed file with 30 additions and 102 deletions.
132 changes: 30 additions & 102 deletions py-polars/polars/internals/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,16 +79,6 @@
from typing_extensions import Literal


def match_dtype(value: Any, dtype: "Type[DataType]") -> Any:
"""
In right hand side operation, make sure that the operand is coerced to the Series dtype
"""
if dtype == Float32 or dtype == Float64:
return float(value)
else:
return int(value)


def get_ffi_func(
name: str,
dtype: Type["DataType"],
Expand Down Expand Up @@ -316,7 +306,7 @@ def _comp(self, other: Any, op: str) -> "Series":
if isinstance(other, Sequence) and not isinstance(other, str):
other = Series("", other)
if isinstance(other, Series):
return Series._from_pyseries(getattr(self._s, op)(other._s))
return wrap_s(getattr(self._s, op)(other._s))
other = maybe_cast(other, self.dtype)
f = get_ffi_func(op + "_<>", self.dtype, self._s)
if f is None:
Expand All @@ -341,100 +331,50 @@ def __ge__(self, other: Any) -> "Series":
def __le__(self, other: Any) -> "Series":
return self._comp(other, "lt_eq")

def __add__(self, other: Any) -> "Series":
if isinstance(other, str):
other = Series("", [other])
def _arithmetic(self, other: Any, op_s: str, op_ffi: str) -> "Series":
if isinstance(other, Series):
return wrap_s(self._s.add(other._s))
return wrap_s(getattr(self._s, op_s)(other._s))
other = maybe_cast(other, self.dtype)
f = get_ffi_func("add_<>", self.dtype, self._s)
f = get_ffi_func(op_ffi, self.dtype, self._s)
if f is None:
return NotImplemented
return wrap_s(f(other))

def __add__(self, other: Any) -> "Series":
if isinstance(other, str):
other = Series("", [other])
return self._arithmetic(other, "add", "add_<>")

def __sub__(self, other: Any) -> "Series":
if isinstance(other, Series):
return Series._from_pyseries(self._s.sub(other._s))
other = maybe_cast(other, self.dtype)
f = get_ffi_func("sub_<>", self.dtype, self._s)
if f is None:
return NotImplemented
return wrap_s(f(other))
return self._arithmetic(other, "sub", "sub_<>")

def __truediv__(self, other: Any) -> "Series":
# this branch is exactly the floordiv function without rounding the floats
if self.is_float():
if isinstance(other, Series):
return Series._from_pyseries(self._s.div(other._s))

other = maybe_cast(other, self.dtype)
f = get_ffi_func("div_<>", self.dtype, self._s)
if f is None:
return NotImplemented
return wrap_s(f(other))
return self._arithmetic(other, "div", "div_<>")

return self.cast(Float64) / other

def __floordiv__(self, other: Any) -> "Series":
if isinstance(other, Series):
if self.is_float():
return Series._from_pyseries(self._s.div(other._s)).floor()
return Series._from_pyseries(self._s.div(other._s))

other = maybe_cast(other, self.dtype)
f = get_ffi_func("div_<>", self.dtype, self._s)
if f is None:
return NotImplemented
result = self._arithmetic(other, "div", "div_<>")
if self.is_float():
return wrap_s(f(other)).floor()
return wrap_s(f(other))
result = result.floor()
return result

def __mul__(self, other: Any) -> "Series":
if isinstance(other, Series):
return Series._from_pyseries(self._s.mul(other._s))
other = maybe_cast(other, self.dtype)
f = get_ffi_func("mul_<>", self.dtype, self._s)
if f is None:
return NotImplemented
return wrap_s(f(other))
return self._arithmetic(other, "mul", "mul_<>")

def __mod__(self, other: Any) -> "Series":
if isinstance(other, Series):
return Series._from_pyseries(self._s.rem(other._s))
other = maybe_cast(other, self.dtype)
f = get_ffi_func("rem_<>", self.dtype, self._s)
if f is None:
return NotImplemented
return wrap_s(f(other))
return self._arithmetic(other, "rem", "rem_<>")

def __rmod__(self, other: Any) -> "Series":
if isinstance(other, Series):
return Series._from_pyseries(other._s.rem(self._s))
other = maybe_cast(other, self.dtype)
other = match_dtype(other, self.dtype)
f = get_ffi_func("rem_<>_rhs", self.dtype, self._s)
if f is None:
return NotImplemented
return wrap_s(f(other))
return self._arithmetic(other, "rem", "rem_<>_rhs")

def __radd__(self, other: Any) -> "Series":
if isinstance(other, Series):
return Series._from_pyseries(self._s.add(other._s))
other = maybe_cast(other, self.dtype)
other = match_dtype(other, self.dtype)
f = get_ffi_func("add_<>_rhs", self.dtype, self._s)
if f is None:
return NotImplemented
return wrap_s(f(other))
return self._arithmetic(other, "add", "add_<>_rhs")

def __rsub__(self, other: Any) -> "Series":
if isinstance(other, Series):
return Series._from_pyseries(other._s.sub(self._s))
other = match_dtype(other, self.dtype)
f = get_ffi_func("sub_<>_rhs", self.dtype, self._s)
if f is None:
return NotImplemented
return wrap_s(f(other))
return self._arithmetic(other, "sub", "sub_<>_rhs")

def __invert__(self) -> "Series":
if self.dtype == Boolean:
Expand All @@ -451,22 +391,10 @@ def __rtruediv__(self, other: Any) -> np.ndarray:
return self.cast(Float64).__rfloordiv__(other) # type: ignore

def __rfloordiv__(self, other: Any) -> "Series":
if isinstance(other, Series):
return Series._from_pyseries(other._s.div(self._s))
other = match_dtype(other, self.dtype)
f = get_ffi_func("div_<>_rhs", self.dtype, self._s)
if f is None:
return NotImplemented
return wrap_s(f(other))
return self._arithmetic(other, "div", "div_<>_rhs")

def __rmul__(self, other: Any) -> "Series":
if isinstance(other, Series):
return Series._from_pyseries(self._s.mul(other._s))
other = match_dtype(other, self.dtype)
f = get_ffi_func("mul_<>", self.dtype, self._s)
if f is None:
return NotImplemented
return wrap_s(f(other))
return self._arithmetic(other, "mul", "mul_<>")

def __pow__(self, power: float, modulo: None = None) -> "Series":
return np.power(self, power) # type: ignore
Expand All @@ -492,7 +420,7 @@ def __getitem__(self, item: Any) -> Any:
return self._s.get_idx(item)
# assume it is boolean mask
if isinstance(item, Series):
return Series._from_pyseries(self._s.filter(item._s))
return wrap_s(self._s.filter(item._s))

if isinstance(item, range):
step: Optional[int]
Expand Down Expand Up @@ -1385,7 +1313,7 @@ def take(self, indices: Union[np.ndarray, List[int]]) -> "Series":
"""
if isinstance(indices, list):
indices = np.array(indices)
return Series._from_pyseries(self._s.take(indices))
return wrap_s(self._s.take(indices))

def null_count(self) -> int:
"""
Expand Down Expand Up @@ -1422,7 +1350,7 @@ def is_null(self) -> "Series":
]
"""
return Series._from_pyseries(self._s.is_null())
return wrap_s(self._s.is_null())

def is_not_null(self) -> "Series":
"""
Expand All @@ -1446,7 +1374,7 @@ def is_not_null(self) -> "Series":
]
"""
return Series._from_pyseries(self._s.is_not_null())
return wrap_s(self._s.is_not_null())

def is_finite(self) -> "Series":
"""
Expand All @@ -1470,7 +1398,7 @@ def is_finite(self) -> "Series":
]
"""
return Series._from_pyseries(self._s.is_finite())
return wrap_s(self._s.is_finite())

def is_infinite(self) -> "Series":
"""
Expand All @@ -1494,7 +1422,7 @@ def is_infinite(self) -> "Series":
]
"""
return Series._from_pyseries(self._s.is_infinite())
return wrap_s(self._s.is_infinite())

def is_nan(self) -> "Series":
"""
Expand All @@ -1519,7 +1447,7 @@ def is_nan(self) -> "Series":
]
"""
return Series._from_pyseries(self._s.is_nan())
return wrap_s(self._s.is_nan())

def is_not_nan(self) -> "Series":
"""
Expand All @@ -1544,7 +1472,7 @@ def is_not_nan(self) -> "Series":
]
"""
return Series._from_pyseries(self._s.is_not_nan())
return wrap_s(self._s.is_not_nan())

def is_in(self, other: Union["Series", List]) -> "Series":
"""
Expand Down Expand Up @@ -1607,7 +1535,7 @@ def arg_true(self) -> "Series":
-------
UInt32 Series
"""
return Series._from_pyseries(self._s.arg_true())
return wrap_s(self._s.arg_true())

def is_unique(self) -> "Series":
"""
Expand Down

0 comments on commit f952801

Please sign in to comment.