Skip to content

Commit

Permalink
fix[rust]: don't unwrap if lp has no input (#4909)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Sep 20, 2022
1 parent 8615047 commit 719cc28
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ fn string_addition_to_linear_concat(
) -> Option<AExpr> {
{
let lp = lp_arena.get(lp_node);
let input = lp.get_input().unwrap();
let input = lp.get_input()?;
let schema = lp_arena.get(input).schema(lp_arena);

let get_type = |ae: &AExpr| ae.get_type(&schema, Context::Default, expr_arena).ok();
Expand Down
5 changes: 4 additions & 1 deletion py-polars/polars/internals/expr/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,10 @@


def selection_to_pyexpr_list(
exprs: str | Expr | pli.Series | Sequence[str | Expr | pli.Series],
exprs: str
| Expr
| pli.Series
| Sequence[str | Expr | pli.Series | date | datetime | int | float],
) -> list[PyExpr]:
if isinstance(exprs, (str, Expr, pli.Series)):
exprs = [exprs]
Expand Down
8 changes: 6 additions & 2 deletions py-polars/polars/internals/lazy_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,9 @@ def max(column: str | Sequence[pli.Expr | str] | pli.Series) -> pli.Expr | Any:


@overload
def min(column: str | Sequence[pli.Expr | str]) -> pli.Expr:
def min(
column: str | Sequence[pli.Expr | str | date | datetime | int | float],
) -> pli.Expr:
...


Expand All @@ -355,7 +357,9 @@ def min(column: pli.Series) -> int | float:
...


def min(column: str | Sequence[pli.Expr | str] | pli.Series) -> pli.Expr | Any:
def min(
column: str | Sequence[pli.Expr | str | date | datetime | int | float] | pli.Series,
) -> pli.Expr | Any:
"""
Get the minimum value.
Expand Down
21 changes: 21 additions & 0 deletions py-polars/tests/unit/test_predicates.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from datetime import date, timedelta

import polars as pl


def test_predicate_4906() -> None:
one_day = timedelta(days=1)

ldf = pl.DataFrame(
{
"dt": [
date(2022, 9, 1),
date(2022, 9, 10),
date(2022, 9, 20),
]
}
).lazy()

assert ldf.filter(
pl.min([(pl.col("dt") + one_day), date(2022, 9, 30)]) > date(2022, 9, 10)
).collect().to_dict(False) == {"dt": [date(2022, 9, 10), date(2022, 9, 20)]}

0 comments on commit 719cc28

Please sign in to comment.