Skip to content

Commit

Permalink
Small improvements for is_in and lit (#4354)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-beedie committed Aug 11, 2022
1 parent 8f73721 commit 0cbbf48
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 27 deletions.
9 changes: 4 additions & 5 deletions py-polars/polars/internals/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2986,15 +2986,14 @@ def pow(self, exponent: int | float | pli.Series | Expr) -> Expr:
exponent = expr_to_lit_or_expr(exponent)
return wrap_expr(self._pyexpr.pow(exponent._pyexpr))

def is_in(self, other: Expr | List[Any] | str) -> Expr:
def is_in(self, other: Expr | Sequence[Any] | str) -> Expr:
"""
Check if elements of this Series are in the right Series, or List values of the
right Series.
Check if elements of this expression are present in the other Series.
Parameters
----------
other
Series of primitive type or List type.
Series or sequence of primitive type.
Returns
-------
Expand All @@ -3020,7 +3019,7 @@ def is_in(self, other: Expr | List[Any] | str) -> Expr:
└──────────┘
"""
if isinstance(other, list):
if isinstance(other, Sequence) and not isinstance(other, str):
other = pli.lit(pli.Series(other))
else:
other = expr_to_lit_or_expr(other, str_to_lit=False)
Expand Down
31 changes: 18 additions & 13 deletions py-polars/polars/internals/lazy_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,23 +642,15 @@ def lit(value: Any, dtype: type[DataType] | None = None) -> pli.Expr:
"""
if isinstance(value, datetime):
tu = "us"
return (
lit(_datetime_to_pl_timestamp(value, tu))
.cast(Datetime)
.dt.with_time_unit(tu)
)
return lit(_datetime_to_pl_timestamp(value, tu)).cast(Datetime(tu))
if isinstance(value, timedelta):
# TODO: python timedelta should also default to 'us' units.
# (needs some corresponding work on the Rust side first)
if timedelta_in_nanoseconds_window(value):
tu = "ns"
else:
tu = "ms"
return (
lit(_timedelta_to_pl_timedelta(value, tu))
.cast(Duration)
.dt.with_time_unit(tu)
)
return lit(_timedelta_to_pl_timedelta(value, tu)).cast(Duration(tu))

if isinstance(value, date):
return lit(datetime(value.year, value.month, value.day)).cast(Date)
Expand All @@ -678,12 +670,25 @@ def lit(value: Any, dtype: type[DataType] | None = None) -> pli.Expr:
return pli.wrap_expr(pylit(value)).cast(dtype)

try:
# numpy literals like np.float32(0) have an item
# numpy literals like np.float32(0) have item/dtype
item = value.item()

# numpy item() is py-native datetime/timedelta when units < 'ns'
if isinstance(item, (datetime, timedelta)):
return lit(item)

# handle 'ns' units
if isinstance(item, int) and hasattr(value, "dtype"):
dtype_name = value.dtype.name
if dtype_name.startswith(("datetime64[", "timedelta64[")):
tu = dtype_name[11:-1]
return lit(item).cast(
Datetime(tu) if dtype_name.startswith("date") else Duration(tu)
)

except AttributeError:
item = value
if isinstance(item, datetime):
return lit(item)

return pli.wrap_expr(pylit(item))


Expand Down
12 changes: 7 additions & 5 deletions py-polars/polars/internals/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ class Series:
def __init__(
self,
name: str | ArrayLike | None = None,
values: ArrayLike | None = None,
values: ArrayLike | Sequence[Any] | None = None,
dtype: type[DataType] | DataType | None = None,
strict: bool = True,
nan_to_null: bool = False,
Expand Down Expand Up @@ -1837,10 +1837,10 @@ def is_not_nan(self) -> Series:
"""
return wrap_s(self._s.is_not_nan())

def is_in(self, other: Series | list[object]) -> Series:
def is_in(self, other: Series | Sequence[object]) -> Series:
"""
Check if elements of this Series are in the right Series, or List values of the
right Series.
Check if elements of this Series are in the other Series, or
if this Series is itself a member of the other Series.
Returns
-------
Expand Down Expand Up @@ -1887,7 +1887,9 @@ def is_in(self, other: Series | list[object]) -> Series:
]
"""
if isinstance(other, list):
if isinstance(other, str):
raise TypeError("'other' parameter expects non-string sequence data")
elif isinstance(other, Sequence):
other = Series("", other)
return wrap_s(self._s.is_in(other._s))

Expand Down
2 changes: 1 addition & 1 deletion py-polars/tests/io/test_lazy_ipc.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def test_row_count(foods_ipc: str) -> None:
def test_is_in_type_coercion(foods_ipc: str) -> None:
out = (
pl.scan_ipc(foods_ipc)
.filter(pl.col("category").is_in(["vegetables"]))
.filter(pl.col("category").is_in(("vegetables", "ice cream")))
.collect()
)
assert out.shape == (7, 4)
Expand Down
2 changes: 1 addition & 1 deletion py-polars/tests/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def test_categorical_is_in_list() -> None:
{"a": [1, 2, 3, 1, 2], "b": ["a", "b", "c", "d", "e"]}
).with_column(pl.col("b").cast(pl.Categorical))

cat_list = ["a", "b", "c"]
cat_list = ("a", "b", "c")
assert df.filter(pl.col("b").is_in(cat_list)).to_dict(False) == {
"a": [1, 2, 3],
"b": ["a", "b", "c"],
Expand Down
10 changes: 8 additions & 2 deletions py-polars/tests/test_datelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,14 @@ def test_datetime_consistency() -> None:
df = pl.DataFrame({"date": [dt]})

assert df["date"].dt[0] == dt
assert df.select(pl.lit(dt))["literal"].dt[0] == dt
assert df.filter(pl.col("date") == dt).rows() == [(dt,)]

for date_literal in (
dt,
np.datetime64(dt, "us"),
np.datetime64(dt, "ns"),
):
assert df.select(pl.lit(date_literal))["literal"].dt[0] == dt
assert df.filter(pl.col("date") == date_literal).rows() == [(dt,)]

ddf = df.select(
[
Expand Down

0 comments on commit 0cbbf48

Please sign in to comment.