Skip to content

Commit

Permalink
reinstate old ternary behavior as experimental (#4244)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Aug 3, 2022
1 parent 6b0041a commit 99ed0d8
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 16 deletions.
20 changes: 13 additions & 7 deletions polars/polars-lazy/src/physical_plan/expressions/ternary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,12 @@ impl PhysicalExpr for TernaryExpr {
groups: &'a GroupsProxy,
state: &ExecutionState,
) -> Result<AggregationContext<'a>> {
if !self.predicate.is_valid_aggregation() {
let aggregation_predicate = self.predicate.is_valid_aggregation();
if !aggregation_predicate {
// unwrap will not fail as it is not an aggregation expression.
return Err(PolarsError::ComputeError(format!("the predicate '{}' in 'when->then->otherwise' is not a valid aggregation and might produce a different number of rows than the groupby operation would", self.predicate.as_expression().unwrap()).into()));
eprintln!(
"The predicate '{}' in 'when->then->otherwise' is not a valid aggregation and might produce a different number of rows than the groupby operation would. This behavior is experimental and may be subject to change", self.predicate.as_expression().unwrap()
)
}
let op_mask = || self.predicate.evaluate_on_groups(df, groups, state);
let op_truthy = || self.truthy.evaluate_on_groups(df, groups, state);
Expand All @@ -140,11 +143,6 @@ impl PhysicalExpr for TernaryExpr {

let mask_s = ac_mask.flat_naive();

assert!(
ac_truthy.can_combine(&ac_falsy),
"cannot combine this ternary expression, the groups do not match"
);

use AggState::*;
match (ac_truthy.agg_state(), ac_falsy.agg_state()) {
// all branches are aggregated-flat or literal
Expand Down Expand Up @@ -175,6 +173,10 @@ impl PhysicalExpr for TernaryExpr {
// otherwise:
// None
(AggregatedList(_), Literal(_)) | (Literal(_), AggregatedList(_)) => {
if !aggregation_predicate {
// experimental elementwise behavior tested in `test_binary_agg_context_1`
return finish_as_iters(ac_truthy, ac_falsy, ac_mask);
}
let mask = mask_s.bool()?;
let check_length = |ca: &ListChunked, mask: &BooleanChunked| {
if ca.len() != mask.len() {
Expand Down Expand Up @@ -262,6 +264,10 @@ impl PhysicalExpr for TernaryExpr {
// Both are either a flat series or aggregated into a list,
// so we can flatten the Series and apply the operators
_ => {
if !aggregation_predicate {
// experimental elementwise behavior tested in `test_binary_agg_context_1`
return finish_as_iters(ac_truthy, ac_falsy, ac_mask);
}
let mut mask = mask_s.bool()?.clone();
let mut truthy = ac_truthy.flat_naive().into_owned();
let mut falsy = ac_falsy.flat_naive().into_owned();
Expand Down
59 changes: 59 additions & 0 deletions polars/polars-lazy/src/tests/aggregations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,65 @@ fn test_binary_agg_context_0() -> Result<()> {
Ok(())
}

// just like binary expression, this must be changed. This can work
#[test]
fn test_binary_agg_context_1() -> Result<()> {
    // Three groups of two rows each:
    //   1 => [1, 13]
    //   2 => [3, 87]
    //   3 => [1, 6]
    let frame = df![
        "groups" => [1, 1, 2, 2, 3, 3],
        "vals" => [1, 13, 3, 87, 1, 6]
    ]?;

    // Case 1: when vals == 1 then sum(vals) otherwise 90.
    // Expected per group:
    //   [14, 90]
    //   [90, 90]
    //   [7, 90]
    let sum_where_one = when(col("vals").eq(lit(1)))
        .then(col("vals").sum())
        .otherwise(lit(90))
        .alias("vals");
    let aggregated = frame
        .clone()
        .lazy()
        .groupby_stable([col("groups")])
        .agg([sum_where_one])
        .collect()?;
    // Flatten the per-group lists so the values can be compared in row order.
    let flattened = aggregated.column("vals")?.explode()?;
    assert_eq!(
        Vec::from(flattened.i32()?),
        &[Some(14), Some(90), Some(90), Some(90), Some(7), Some(90)]
    );

    // Case 2: the branches are swapped: when vals == 1 then 90 otherwise sum(vals).
    // Expected per group:
    //   [90, 14]
    //   [90, 90]
    //   [90, 7]
    let sum_where_not_one = when(col("vals").eq(lit(1)))
        .then(lit(90))
        .otherwise(col("vals").sum())
        .alias("vals");
    let aggregated = frame
        .lazy()
        .groupby_stable([col("groups")])
        .agg([sum_where_not_one])
        .collect()?;
    let flattened = aggregated.column("vals")?.explode()?;
    assert_eq!(
        Vec::from(flattened.i32()?),
        &[Some(90), Some(14), Some(90), Some(90), Some(90), Some(7)]
    );

    Ok(())
}

#[test]
fn test_binary_agg_context_2() -> Result<()> {
let df = df![
Expand Down
9 changes: 0 additions & 9 deletions py-polars/tests/test_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,12 +140,3 @@ def test_getitem_errs() -> None:

with pytest.raises(ValueError, match="Cannot __setitem__ on DataFrame with key:.*"):
df[{"some"}] = "foo"


def test_invalid_predication_ternary() -> None:
    """Expect `pl.ComputeError` from a groupby aggregation whose ternary predicate is a row-wise comparison."""
    frame = pl.DataFrame({"name": ["a", "b", "a", "b"], "value": [1, 3, 2, 1]})

    with pytest.raises(pl.ComputeError):
        # The expression is built inside the `raises` block so that an error
        # raised at construction time is covered as well.
        ranked_above_two = (
            pl.when(pl.col("value") > 2).then(pl.col("value").rank()).otherwise(None)
        )
        frame.groupby("name").agg(ranked_above_two)

0 comments on commit 99ed0d8

Please sign in to comment.