Skip to content

Commit

Permalink
add support for comparison of pl.Series and scalar with different typ…
Browse files Browse the repository at this point in the history
…es, add mapping of polars-types to python types and add tests
  • Loading branch information
CloseChoice authored and ritchie46 committed Oct 19, 2021
1 parent bf94b7e commit fc90cf5
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 0 deletions.
15 changes: 15 additions & 0 deletions py-polars/polars/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,21 @@ def numpy_type_to_constructor(dtype: Type[np.dtype]) -> Callable[..., "PySeries"
bool: Boolean,
}

_DTYPE_TO_PY_TYPE = {
Float64: float,
Float32: float,
Int64: int,
Int32: int,
Int16: int,
Int8: int,
Utf8: str,
UInt8: int,
UInt16: int,
UInt32: int,
UInt64: int,
Boolean: bool,
}


def py_type_to_constructor(dtype: Type[Any]) -> Callable[..., "PySeries"]:
"""
Expand Down
22 changes: 22 additions & 0 deletions py-polars/polars/eager/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
_DOCUMENTING = True

from ..datatypes import (
_DTYPE_TO_PY_TYPE,
DTYPE_TO_FFINAME,
DTYPES,
Boolean,
Expand Down Expand Up @@ -113,6 +114,13 @@ def wrap_s(s: "PySeries") -> "Series":
return Series._from_pyseries(s)


def _maybe_cast(el: "Type[DataType]", dtype: Type) -> "Type[DataType]":
# cast el if it doesn't match
if not isinstance(el, _DTYPE_TO_PY_TYPE[dtype]):
el = _DTYPE_TO_PY_TYPE[dtype](el)
return el


ArrayLike = Union[
Sequence[Any], "Series", "pa.Array", np.ndarray, "pd.Series", "pd.DatetimeIndex"
]
Expand Down Expand Up @@ -298,6 +306,7 @@ def __eq__(self, other: Any) -> "Series": # type: ignore[override]
other = Series("", other)
if isinstance(other, Series):
return Series._from_pyseries(self._s.eq(other._s))
other = _maybe_cast(other, self.dtype)
f = get_ffi_func("eq_<>", self.dtype, self._s)
if f is None:
return NotImplemented
Expand All @@ -308,6 +317,7 @@ def __ne__(self, other: Any) -> "Series": # type: ignore[override]
other = Series("", other)
if isinstance(other, Series):
return Series._from_pyseries(self._s.neq(other._s))
other = _maybe_cast(other, self.dtype)
f = get_ffi_func("neq_<>", self.dtype, self._s)
if f is None:
return NotImplemented
Expand All @@ -318,6 +328,7 @@ def __gt__(self, other: Any) -> "Series":
other = Series("", other)
if isinstance(other, Series):
return Series._from_pyseries(self._s.gt(other._s))
other = _maybe_cast(other, self.dtype)
f = get_ffi_func("gt_<>", self.dtype, self._s)
if f is None:
return NotImplemented
Expand All @@ -328,6 +339,8 @@ def __lt__(self, other: Any) -> "Series":
other = Series("", other)
if isinstance(other, Series):
return Series._from_pyseries(self._s.lt(other._s))
# cast other if it doesn't match
other = _maybe_cast(other, self.dtype)
f = get_ffi_func("lt_<>", self.dtype, self._s)
if f is None:
return NotImplemented
Expand All @@ -338,6 +351,7 @@ def __ge__(self, other: Any) -> "Series":
other = Series("", other)
if isinstance(other, Series):
return Series._from_pyseries(self._s.gt_eq(other._s))
other = _maybe_cast(other, self.dtype)
f = get_ffi_func("gt_eq_<>", self.dtype, self._s)
if f is None:
return NotImplemented
Expand All @@ -348,6 +362,7 @@ def __le__(self, other: Any) -> "Series":
other = Series("", other)
if isinstance(other, Series):
return Series._from_pyseries(self._s.lt_eq(other._s))
other = _maybe_cast(other, self.dtype)
f = get_ffi_func("lt_eq_<>", self.dtype, self._s)
if f is None:
return NotImplemented
Expand All @@ -358,6 +373,7 @@ def __add__(self, other: Any) -> "Series":
other = Series("", [other])
if isinstance(other, Series):
return wrap_s(self._s.add(other._s))
other = _maybe_cast(other, self.dtype)
dtype = date_like_to_physical(self.dtype)
f = get_ffi_func("add_<>", dtype, self._s)
if f is None:
Expand All @@ -367,6 +383,7 @@ def __add__(self, other: Any) -> "Series":
def __sub__(self, other: Any) -> "Series":
if isinstance(other, Series):
return Series._from_pyseries(self._s.sub(other._s))
other = _maybe_cast(other, self.dtype)
dtype = date_like_to_physical(self.dtype)
f = get_ffi_func("sub_<>", dtype, self._s)
if f is None:
Expand All @@ -382,13 +399,15 @@ def __truediv__(self, other: Any) -> "Series":
def __floordiv__(self, other: Any) -> "Series":
if isinstance(other, Series):
return Series._from_pyseries(self._s.div(other._s))
other = _maybe_cast(other, self.dtype)
dtype = date_like_to_physical(self.dtype)
f = get_ffi_func("div_<>", dtype, self._s)
return wrap_s(f(other))

def __mul__(self, other: Any) -> "Series":
if isinstance(other, Series):
return Series._from_pyseries(self._s.mul(other._s))
other = _maybe_cast(other, self.dtype)
dtype = date_like_to_physical(self.dtype)
f = get_ffi_func("mul_<>", dtype, self._s)
if f is None:
Expand All @@ -398,6 +417,7 @@ def __mul__(self, other: Any) -> "Series":
def __mod__(self, other: Any) -> "Series":
if isinstance(other, Series):
return Series._from_pyseries(self._s.rem(other._s))
other = _maybe_cast(other, self.dtype)
dtype = date_like_to_physical(self.dtype)
f = get_ffi_func("rem_<>", dtype, self._s)
if f is None:
Expand All @@ -408,6 +428,7 @@ def __rmod__(self, other: Any) -> "Series":
if isinstance(other, Series):
return Series._from_pyseries(other._s.rem(self._s))
dtype = date_like_to_physical(self.dtype)
other = _maybe_cast(other, self.dtype)
other = match_dtype(other, dtype)
f = get_ffi_func("rem_<>_rhs", dtype, self._s)
if f is None:
Expand All @@ -418,6 +439,7 @@ def __radd__(self, other: Any) -> "Series":
if isinstance(other, Series):
return Series._from_pyseries(self._s.add(other._s))
dtype = date_like_to_physical(self.dtype)
other = _maybe_cast(other, self.dtype)
other = match_dtype(other, dtype)
f = get_ffi_func("add_<>_rhs", dtype, self._s)
if f is None:
Expand Down
59 changes: 59 additions & 0 deletions py-polars/tests/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,3 +697,62 @@ def test_from_sequences():
b = pl.Series("a", vals)
assert a.series_equal(b, null_equal=True)
assert a.to_list() == vals


def test_comparisons_int_series_to_float():
srs_int = pl.Series([1, 2, 3, 4])
assert (srs_int - 1.0).to_list() == [0, 1, 2, 3]
assert (srs_int + 1.0).to_list() == [2, 3, 4, 5]
assert (srs_int * 2.0).to_list() == [2, 4, 6, 8]
# todo: this is inconsistant
assert (srs_int / 2.0).to_list() == [0.5, 1.0, 1.5, 2.0]
assert (srs_int % 2.0).to_list() == [1, 0, 1, 0]
assert (4.0 % srs_int).to_list() == [0, 0, 1, 0]
# floordiv is implemented as div
assert (srs_int // 2.0).to_list() == [0, 1, 1, 2]
assert (srs_int < 3.0).to_list() == [True, True, False, False]
assert (srs_int <= 3.0).to_list() == [True, True, True, False]
assert (srs_int > 3.0).to_list() == [False, False, False, True]
assert (srs_int >= 3.0).to_list() == [False, False, True, True]
assert (srs_int == 3.0).to_list() == [False, False, True, False]
assert (srs_int - True).to_list() == [0, 1, 2, 3]


def test_comparisons_float_series_to_int():
srs_float = pl.Series([1.0, 2.0, 3.0, 4.0])
assert (srs_float - 1).to_list() == [0.0, 1.0, 2.0, 3.0]
assert (srs_float + 1).to_list() == [2.0, 3.0, 4.0, 5.0]
assert (srs_float * 2).to_list() == [2.0, 4.0, 6.0, 8.0]
assert (srs_float / 2).to_list() == [0.5, 1.0, 1.5, 2.0]
assert (srs_float % 2).to_list() == [1.0, 0.0, 1.0, 0.0]
assert (4 % srs_float).to_list() == [0.0, 0.0, 1.0, 0.0]
# floordiv is implemented as div
assert (srs_float // 2).to_list() == [0.5, 1.0, 1.5, 2.0]
assert (srs_float < 3).to_list() == [True, True, False, False]
assert (srs_float <= 3).to_list() == [True, True, True, False]
assert (srs_float > 3).to_list() == [False, False, False, True]
assert (srs_float >= 3).to_list() == [False, False, True, True]
assert (srs_float == 3).to_list() == [False, False, True, False]
assert (srs_float - True).to_list() == [0.0, 1.0, 2.0, 3.0]


def test_comparisons_bool_series_to_int():
srs_bool = pl.Series([True, False])
# todo: do we want this to work?
assert (srs_bool / 1).to_list() == [True, False]
with pytest.raises(TypeError, match=r"\-: 'Series' and 'int'"):
srs_bool - 1
with pytest.raises(TypeError, match=r"\+: 'Series' and 'int'"):
srs_bool + 1
with pytest.raises(TypeError, match=r"\%: 'Series' and 'int'"):
srs_bool % 2
with pytest.raises(TypeError, match=r"\*: 'Series' and 'int'"):
srs_bool * 1
with pytest.raises(
TypeError, match=r"'<' not supported between instances of 'Series' and 'int'"
):
srs_bool < 2
with pytest.raises(
TypeError, match=r"'>' not supported between instances of 'Series' and 'int'"
):
srs_bool > 2

0 comments on commit fc90cf5

Please sign in to comment.