Skip to content

Commit

Permalink
chained when_then operation
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Jun 4, 2021
1 parent e074101 commit cdcc37c
Show file tree
Hide file tree
Showing 4 changed files with 187 additions and 0 deletions.
96 changes: 96 additions & 0 deletions polars/polars-lazy/src/dsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -385,15 +385,31 @@ pub fn binary_expr(l: Expr, op: Operator, r: Expr) -> Expr {
}
}

/// Intermediate state of `when(..).then(..).otherwise(..)` expr.
pub struct When {
predicate: Expr,
}

/// Intermediate state of `when(..).then(..).otherwise(..)` expr.
pub struct WhenThen {
predicate: Expr,
then: Expr,
}

/// Intermediate state of chain when then exprs.
///
/// ```ignore
/// when(..).then(..)
/// when(..).then(..)
/// when(..).then(..)
/// .otherwise(..)`
/// ```
#[derive(Clone)]
pub struct WhenThenThen {
predicates: Vec<Expr>,
thens: Vec<Expr>,
}

impl When {
pub fn then(self, expr: Expr) -> WhenThen {
WhenThen {
Expand All @@ -404,6 +420,13 @@ impl When {
}

impl WhenThen {
pub fn when(self, predicate: Expr) -> WhenThenThen {
WhenThenThen {
predicates: vec![self.predicate, predicate],
thens: vec![self.then],
}
}

pub fn otherwise(self, expr: Expr) -> Expr {
Expr::Ternary {
predicate: Box::new(self.predicate),
Expand All @@ -413,6 +436,68 @@ impl WhenThen {
}
}

impl WhenThenThen {
pub fn then(mut self, expr: Expr) -> Self {
self.thens.push(expr);
self
}

pub fn when(mut self, predicate: Expr) -> Self {
self.predicates.push(predicate);
self
}

pub fn otherwise(self, expr: Expr) -> Expr {
// we iterate the preds/ exprs last in first out
// and nest them.
//
// // this expr:
// when((col('x') == 'a')).then(1)
// .when(col('x') == 'a').then(2)
// .when(col('x') == 'b').then(3)
// .otherwise(4)
//
// needs to become:
// when((col('x') == 'a')).then(1) -
// .otherwise( |
// when(col('x') == 'a').then(2) - |
// .otherwise( | |
// pl.when(col('x') == 'b').then(3) | |
// .otherwise(4) | inner | outer
// ) | |
// ) _| _|
//
// by iterating lifo we first create
// `inner` and then assighn that to `otherwise`,
// which will be used in the next layer `outer`
//

let pred_iter = self.predicates.into_iter().rev();
let mut then_iter = self.thens.into_iter().rev();

let mut otherwise = expr;

for e in pred_iter {
otherwise = Expr::Ternary {
predicate: Box::new(e),
truthy: Box::new(
then_iter
.next()
.expect("expr expected, did you call when().then().otherwise?"),
),
falsy: Box::new(otherwise),
}
}
if then_iter.next().is_some() {
panic!(
"this expr is not properly constructed. \
Every `when` should have an accompanied `then` call."
)
}
otherwise
}
}

/// Start a when-then-otherwise expression
pub fn when(predicate: Expr) -> When {
When { predicate }
Expand Down Expand Up @@ -1301,4 +1386,15 @@ mod test {
);
Ok(())
}

#[test]
fn test_when_then_when_then() {
let e = when(col("a"))
.then(col("b"))
.when(col("c"))
.then(col("d"))
.otherwise(col("f"));

dbg!(e);
}
}
50 changes: 50 additions & 0 deletions py-polars/polars/lazy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1679,6 +1679,39 @@ def expr_to_lit_or_expr(
return expr


class WhenThenThen:
"""
Utility class. See the `when` function.
"""

def __init__(self, pywhenthenthen):
self.pywenthenthen = pywhenthenthen

def when(self, predicate: "Expr") -> "WhenThenThen":
"""
start another when, then, otherwise layer
"""
return WhenThenThen(self.pywenthenthen.when(predicate._pyexpr))

def then(self, expr: "Union[Expr, int, float, str]") -> "WhenThenThen":
"""
Values to return in case of the predicate being `True`
See Also: the `when` function.
"""
expr = expr_to_lit_or_expr(expr)
return WhenThenThen(self.pywenthenthen.then(expr._pyexpr))

def otherwise(self, expr: "Union[Expr, int, float, str]") -> "Expr":
"""
Values to return in case of the predicate being `False`
See Also: the `when` function.
"""
expr = expr_to_lit_or_expr(expr)
return wrap_expr(self.pywenthenthen.otherwise(expr._pyexpr))


class WhenThen:
"""
Utility class. See the `when` function.
Expand All @@ -1687,6 +1720,12 @@ class WhenThen:
def __init__(self, pywhenthen: "PyWhenThen"): # noqa F821
self._pywhenthen = pywhenthen

def when(self, predicate: "Expr"):
"""
start another when, then, otherwise layer
"""
return WhenThenThen(self._pywhenthen.when(predicate._pyexpr))

def otherwise(self, expr: "Union[Expr, int, float, str]") -> "Expr":
"""
Values to return in case of the predicate being `False`
Expand Down Expand Up @@ -1731,6 +1770,17 @@ def when(expr: "Expr") -> When:
.otherwise(lit(-1))
)
```
Or with multiple `when, thens` chained:
```python
lf.with_column(
when(col("foo") > 2).then(1)
when(col("bar") > 2).then(4)
.otherwise(-1)
)
```
"""
expr = expr_to_lit_or_expr(expr)
pw = pywhen(expr._pyexpr)
Expand Down
29 changes: 29 additions & 0 deletions py-polars/src/lazy/dsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,13 @@ pub struct WhenThen {
then: PyExpr,
}

#[pyclass]
#[derive(Clone)]
pub struct WhenThenThen {
inner: dsl::WhenThenThen,
}


#[pymethods]
impl When {
pub fn then(&self, expr: PyExpr) -> WhenThen {
Expand All @@ -446,6 +453,15 @@ impl When {

#[pymethods]
impl WhenThen {
pub fn when(&self, predicate: PyExpr) -> WhenThenThen {
let e = dsl::when(self.predicate.inner.clone())
.then(self.then.inner.clone())
.when(predicate.inner);
WhenThenThen {
inner: e
}
}

pub fn otherwise(&self, expr: PyExpr) -> PyExpr {
dsl::ternary_expr(
self.predicate.inner.clone(),
Expand All @@ -456,6 +472,19 @@ impl WhenThen {
}
}

#[pymethods]
impl WhenThenThen {
pub fn when(&self, predicate: PyExpr) -> WhenThenThen {
Self { inner: self.inner.clone().when(predicate.inner) }
}
pub fn then(&self, expr: PyExpr) -> WhenThenThen {
Self { inner: self.inner.clone().then(expr.inner) }
}
pub fn otherwise(&self, expr: PyExpr) -> PyExpr {
self.inner.clone().otherwise(expr.clone().inner).into()
}
}

pub fn when(predicate: PyExpr) -> When {
When { predicate }
}
Expand Down
12 changes: 12 additions & 0 deletions py-polars/tests/test_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,3 +177,15 @@ def test_window_function():

out = df[[pl.first("B").over(["fruits", "cars"])]]
assert out["B_first"] == [5, 4, 3, 3, 5]


def test_when_then_flatten():
df = pl.DataFrame({"foo": [1, 2, 3], "bar": [3, 4, 5]})

assert df[
when(col("foo") > 1)
.then(col("bar"))
.when(col("bar") < 3)
.then(10)
.otherwise(30)
]["bar"] == [30, 4, 5]

0 comments on commit cdcc37c

Please sign in to comment.