Skip to content

Commit

Permalink
feat(python): allow implicit None branch in when then otherwise (#5264)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Oct 19, 2022
1 parent 94d1fa9 commit 235b4bb
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 4 deletions.
4 changes: 3 additions & 1 deletion py-polars/polars/internals/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
)
from polars.internals.lazyframe import LazyFrame, wrap_ldf
from polars.internals.series import Series, wrap_s
from polars.internals.whenthen import when # used in expr.clip()
from polars.internals.whenthen import WhenThen, WhenThenThen, when

__all__ = [
"DataFrame",
Expand Down Expand Up @@ -74,6 +74,8 @@
"wrap_expr",
"wrap_ldf",
"wrap_s",
"WhenThen",
"WhenThenThen",
"_deser_and_exec",
"_is_local_file",
"_prepare_file_arg",
Expand Down
5 changes: 4 additions & 1 deletion py-polars/polars/internals/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -5214,7 +5214,10 @@ def lazy(self: DF) -> pli.LazyFrame:

def select(
self: DF,
exprs: str | pli.Expr | pli.Series | Sequence[str | pli.Expr | pli.Series],
exprs: str
| pli.Expr
| pli.Series
| Sequence[str | pli.Expr | pli.Series | pli.WhenThen | pli.WhenThenThen],
) -> DF:
"""
Select columns from this DataFrame.
Expand Down
18 changes: 17 additions & 1 deletion py-polars/polars/internals/expr/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,18 @@ def selection_to_pyexpr_list(
exprs: str
| Expr
| pli.Series
| Sequence[str | Expr | pli.Series | timedelta | date | datetime | int | float],
| Sequence[
str
| Expr
| pli.Series
| timedelta
| date
| datetime
| int
| float
| pli.WhenThen
| pli.WhenThenThen
],
) -> list[PyExpr]:
if isinstance(exprs, (str, Expr, pli.Series)):
exprs = [exprs]
Expand All @@ -73,6 +84,8 @@ def expr_to_lit_or_expr(
| datetime
| time
| timedelta
| pli.WhenThen
| pli.WhenThenThen
| Sequence[(int | float | str | None)]
),
str_to_lit: bool = True,
Expand Down Expand Up @@ -104,6 +117,9 @@ def expr_to_lit_or_expr(
return expr
elif isinstance(expr, list):
return pli.lit(pli.Series("", [expr]))
elif isinstance(expr, (pli.WhenThen, pli.WhenThenThen)):
# implicitly add the null branch.
return expr.otherwise(None)
else:
raise ValueError(
f"did not expect value {expr} of type {type(expr)}, maybe disambiguate with"
Expand Down
5 changes: 4 additions & 1 deletion py-polars/polars/internals/lazyframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -1113,7 +1113,10 @@ def filter(self: LDF, predicate: pli.Expr | str | pli.Series | list[bool]) -> LD

def select(
self: LDF,
exprs: str | pli.Expr | pli.Series | Sequence[str | pli.Expr | pli.Series],
exprs: str
| pli.Expr
| pli.Series
| Sequence[str | pli.Expr | pli.Series | pli.WhenThen | pli.WhenThenThen],
) -> LDF:
"""
Select columns from this DataFrame.
Expand Down
11 changes: 11 additions & 0 deletions py-polars/polars/internals/whenthen.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import typing
from typing import Any, Sequence

try:
Expand Down Expand Up @@ -65,6 +66,11 @@ def otherwise(
expr = pli.expr_to_lit_or_expr(expr)
return pli.wrap_expr(self.pywhenthenthen.otherwise(expr._pyexpr))

@typing.no_type_check
def __getattr__(self, item) -> pli.Expr:
expr = self.otherwise(None) # noqa: F841
return eval(f"expr.{item}")


class WhenThen:
"""Utility class. See the `when` function."""
Expand All @@ -90,6 +96,11 @@ def otherwise(self, expr: pli.Expr | int | float | str | None) -> pli.Expr:
expr = pli.expr_to_lit_or_expr(expr)
return pli.wrap_expr(self._pywhenthen.otherwise(expr._pyexpr))

@typing.no_type_check
def __getattr__(self, item) -> pli.Expr:
expr = self.otherwise(None) # noqa: F841
return eval(f"expr.{item}")


class When:
"""Utility class. See the `when` function."""
Expand Down
19 changes: 19 additions & 0 deletions py-polars/tests/unit/test_predicates.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,22 @@ def test_predicate_4906() -> None:
assert ldf.filter(
pl.min([(pl.col("dt") + one_day), date(2022, 9, 30)]) > date(2022, 9, 10)
).collect().to_dict(False) == {"dt": [date(2022, 9, 10), date(2022, 9, 20)]}


def test_when_then_implicit_none() -> None:
df = pl.DataFrame(
{
"team": ["A", "A", "A", "B", "B", "C"],
"points": [11, 8, 10, 6, 6, 5],
}
)

assert df.select(
[
pl.when(pl.col("points") > 7).then("Foo"),
pl.when(pl.col("points") > 7).then("Foo").alias("bar"),
]
).to_dict(False) == {
"literal": ["Foo", "Foo", "Foo", None, None, None],
"bar": ["Foo", "Foo", "Foo", None, None, None],
}

0 comments on commit 235b4bb

Please sign in to comment.