Skip to content

Commit

Permalink
fix(python): Refactor is_between (#5491)
Browse files Browse the repository at this point in the history
  • Loading branch information
zundertj committed Nov 13, 2022
1 parent dd5d814 commit f304c96
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 18 deletions.
25 changes: 7 additions & 18 deletions py-polars/polars/internals/expr/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from polars import internals as pli
from polars.datatypes import (
DataType,
Datetime,
PolarsDataType,
UInt32,
is_polars_dtype,
Expand Down Expand Up @@ -3527,8 +3526,8 @@ def repeat_by(self, by: Expr | str) -> Expr:

def is_between(
self,
start: Expr | datetime | int,
end: Expr | datetime | int,
start: Expr | datetime | date | int | float,
end: Expr | datetime | date | int | float,
include_bounds: bool | tuple[bool, bool] = False,
) -> Expr:
"""
Expand Down Expand Up @@ -3574,32 +3573,22 @@ def is_between(
└─────┴────────────┘
"""
cast_to_datetime = False
if isinstance(start, datetime):
start = pli.lit(start)
cast_to_datetime = True
if isinstance(end, datetime):
end = pli.lit(end)
cast_to_datetime = True
if cast_to_datetime:
expr = self.cast(Datetime)
else:
expr = self
if isinstance(include_bounds, list):
warnings.warn(
"include_bounds: list[bool] will not be supported in a future "
"version; pass include_bounds: tuple[bool, bool] instead",
category=DeprecationWarning,
)
include_bounds = tuple(include_bounds)

if include_bounds is False or include_bounds == (False, False):
return ((expr > start) & (expr < end)).alias("is_between")
return ((self > start) & (self < end)).alias("is_between")
elif include_bounds is True or include_bounds == (True, True):
return ((expr >= start) & (expr <= end)).alias("is_between")
return ((self >= start) & (self <= end)).alias("is_between")
elif include_bounds == (False, True):
return ((expr > start) & (expr <= end)).alias("is_between")
return ((self > start) & (self <= end)).alias("is_between")
elif include_bounds == (True, False):
return ((expr >= start) & (expr < end)).alias("is_between")
return ((self >= start) & (self < end)).alias("is_between")
else:
raise ValueError("include_bounds should be a bool or tuple[bool, bool].")

Expand Down
38 changes: 38 additions & 0 deletions py-polars/tests/unit/test_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1210,6 +1210,44 @@ def test_is_between(fruits_cars: pl.DataFrame) -> None:
)


def test_is_between_data_types() -> None:
df = pl.DataFrame(
{
"flt": [1.4, 1.2, 2.5],
"int": [2, 3, 4],
"date": [date(2020, 1, 1), date(2020, 2, 2), date(2020, 3, 3)],
"datetime": [
datetime(2020, 1, 1, 0, 0, 0),
datetime(2020, 1, 1, 10, 0, 0),
datetime(2020, 1, 1, 12, 0, 0),
],
}
)

# on purpose, for float and int, we pass in a mixture of bound data types
assert_series_equal(
df.select(pl.col("flt").is_between(1, 2.3))[:, 0],
pl.Series("is_between", [True, True, False]),
)
assert_series_equal(
df.select(pl.col("int").is_between(1.5, 4))[:, 0],
pl.Series("is_between", [True, True, False]),
)

assert_series_equal(
df.select(pl.col("date").is_between(date(2019, 1, 1), date(2020, 2, 5)))[:, 0],
pl.Series("is_between", [True, True, False]),
)
assert_series_equal(
df.select(
pl.col("datetime").is_between(
datetime(2020, 1, 1, 5, 0, 0), datetime(2020, 1, 1, 11, 0, 0)
)
)[:, 0],
pl.Series("is_between", [False, True, False]),
)


def test_unique() -> None:
df = pl.DataFrame({"a": [1, 2, 2], "b": [3, 3, 3]})

Expand Down

0 comments on commit f304c96

Please sign in to comment.