Skip to content

Commit

Permalink
improve when then otherwise for lists (#3614)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Jun 7, 2022
1 parent 938a9be commit 11a6b56
Show file tree
Hide file tree
Showing 7 changed files with 91 additions and 9 deletions.
5 changes: 5 additions & 0 deletions polars/polars-core/src/datatypes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -733,6 +733,11 @@ impl PartialEq for DataType {
#[cfg(feature = "dtype-categorical")]
(Categorical(_), Categorical(_)) => true,
(Datetime(tu_l, tz_l), Datetime(tu_r, tz_r)) => tu_l == tu_r && tz_l == tz_r,
(List(left_inner), List(right_inner)) => left_inner == right_inner,
#[cfg(feature = "dtype-duration")]
(Duration(tu_l), Duration(tu_r)) => tu_l == tu_r,
#[cfg(feature = "object")]
(Object(lhs), Object(rhs)) => lhs == rhs,
_ => std::mem::discriminant(self) == std::mem::discriminant(other),
}
}
Expand Down
14 changes: 11 additions & 3 deletions polars/polars-lazy/src/logical_plan/optimizer/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ pub struct TypeCoercionRule {}

/// Determine whether to use the supertype or not. For instance, when we have an Int64 column and compare it with a UInt32 literal,
/// it would be wasteful to cast the column instead of the literal.
fn use_supertype(
fn modify_supertype(
mut st: DataType,
left: &AExpr,
right: &AExpr,
Expand Down Expand Up @@ -67,6 +67,14 @@ fn use_supertype(
| (Utf8, Categorical(_), AExpr::Literal(_), _) => {
st = DataType::Categorical(None);
}
// List literals in a when-then expression can have a different inner type than the other side,
// so we cast the literal to the other side's list type.
(List(inner), List(other), _, AExpr::Literal(_))
| (List(other), List(inner), AExpr::Literal(_), _)
if inner != other =>
{
st = DataType::List(inner.clone())
}
// do nothing
_ => {}
}
Expand Down Expand Up @@ -117,7 +125,7 @@ impl OptimizationRule for TypeCoercionRule {
None
} else {
let st = get_supertype(&type_true, &type_false).expect("supertype");
let st = use_supertype(st, truthy, falsy, &type_true, &type_false);
let st = modify_supertype(st, truthy, falsy, &type_true, &type_false);

// only cast if the type is not already the super type.
// this can prevent an expensive flattening and subsequent aggregation
Expand Down Expand Up @@ -270,7 +278,7 @@ impl OptimizationRule for TypeCoercionRule {
let st = get_supertype(&type_left, &type_right)
.expect("could not find supertype of binary expr");

let mut st = use_supertype(st, left, right, &type_left, &type_right);
let mut st = modify_supertype(st, left, right, &type_left, &type_right);

#[allow(unused_mut, unused_assignments)]
let mut cat_str_arithmetic = false;
Expand Down
1 change: 1 addition & 0 deletions polars/polars-lazy/src/tests/projection_queries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ fn scan_join_same_file() -> Result<()> {
}

#[test]
#[cfg(all(feature = "regex", feature = "concat_str"))]
fn concat_str_regex_expansion() -> Result<()> {
let df = df![
"a"=> [1, 1, 1],
Expand Down
20 changes: 19 additions & 1 deletion py-polars/polars/internals/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -5286,7 +5286,23 @@ def nanoseconds(self) -> Expr:


def expr_to_lit_or_expr(
expr: Union[Expr, bool, int, float, str, "pli.Series", None],
expr: Union[
Expr,
bool,
int,
float,
str,
"pli.Series",
None,
Sequence[
Union[
int,
float,
str,
None,
]
],
],
str_to_lit: bool = True,
) -> Expr:
"""
Expand All @@ -5312,6 +5328,8 @@ def expr_to_lit_or_expr(
return pli.lit(expr)
elif isinstance(expr, Expr):
return expr
elif isinstance(expr, list):
return pli.lit(pli.Series("", [expr]))
else:
raise ValueError(
f"did not expect value {expr} of type {type(expr)}, maybe disambiguate with pl.lit or pl.col"
Expand Down
47 changes: 43 additions & 4 deletions py-polars/polars/internals/whenthen.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Union
from typing import Any, Sequence, Union

try:
from polars.polars import when as pywhen
Expand All @@ -24,7 +24,24 @@ def when(self, predicate: pli.Expr) -> "WhenThenThen":
"""
return WhenThenThen(self.pywhenthenthen.when(predicate._pyexpr))

def then(self, expr: Union[pli.Expr, int, float, str, None]) -> "WhenThenThen":
def then(
self,
expr: Union[
pli.Expr,
int,
float,
str,
None,
Sequence[
Union[
int,
float,
str,
None,
]
],
],
) -> "WhenThenThen":
"""
Values to return in case of the predicate being `True`.
Expand All @@ -33,7 +50,24 @@ def then(self, expr: Union[pli.Expr, int, float, str, None]) -> "WhenThenThen":
expr_ = pli.expr_to_lit_or_expr(expr)
return WhenThenThen(self.pywhenthenthen.then(expr_._pyexpr))

def otherwise(self, expr: Union[pli.Expr, int, float, str, None]) -> pli.Expr:
def otherwise(
self,
expr: Union[
pli.Expr,
int,
float,
str,
None,
Sequence[
Union[
int,
float,
str,
None,
]
],
],
) -> pli.Expr:
"""
Values to return in case of the predicate being `False`.
Expand Down Expand Up @@ -75,7 +109,12 @@ class When:
def __init__(self, pywhen: "pywhen"):
self._pywhen = pywhen

def then(self, expr: Union[pli.Expr, int, float, str, None]) -> WhenThen:
def then(
self,
expr: Union[
pli.Expr, int, float, str, None, Sequence[Union[None, int, float, str]]
],
) -> WhenThen:
"""
Values to return in case of the predicate being `True`.
Expand Down
11 changes: 11 additions & 0 deletions py-polars/tests/test_lists.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,3 +252,14 @@ def test_list_fill_null() -> None:
.alias("C")
]
).to_series().to_list() == [["a", "b", "c"], None, None, ["d", "e"]]


def test_list_fill_list() -> None:
    """Empty lists in a list column are replaced by a list literal passed to `then`."""
    df = pl.DataFrame({"a": [[1, 2, 3], []]})

    # Rows whose list has zero elements get the literal `[5]`; others keep their value.
    is_empty = pl.col("a").arr.lengths() == 0
    filled = pl.when(is_empty).then([5]).otherwise(pl.col("a")).alias("filled")

    result = df.select([filled]).to_dict(False)
    assert result == {"filled": [[1, 2, 3], [5]]}
2 changes: 1 addition & 1 deletion py-polars/tests/test_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def test_repeat_expansion_in_groupby() -> None:
pl.DataFrame({"g": [1, 2, 2, 3, 3, 3]})
.groupby("g", maintain_order=True)
.agg(pl.repeat(1, pl.count()).cumsum())
.to_dict()
.to_dict(False)
)
assert out == {"g": [1, 2, 3], "literal": [[1], [1, 2], [1, 2, 3]]}

Expand Down

0 comments on commit 11a6b56

Please sign in to comment.