Skip to content

Commit

Permalink
fix(rust, python): fix aggregation that filters out all data (#6036)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Jan 4, 2023
1 parent 78f3859 commit fde4e9c
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 7 deletions.
12 changes: 5 additions & 7 deletions polars/polars-lazy/src/physical_plan/expressions/apply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,26 +140,24 @@ impl PhysicalExpr for ApplyExpr {
return Err(expression_err!(msg, self.expr, ComputeError));
}

let name = s.name().to_string();
let agg = ac.aggregated();
// collection of empty list leads to a null dtype
// see: #3687
if s.len() == 0 {
if agg.len() == 0 {
// create input for the function to determine the output dtype
// see #3946
let agg = ac.aggregated();
let agg = agg.list().unwrap();
let input_dtype = agg.inner_dtype();

let input = Series::full_null("", 0, &input_dtype);

let output = self.function.call_udf(&mut [input])?;
let ca = ListChunked::full(ac.series().name(), &output, 0);
let ca = ListChunked::full(&name, &output, 0);
return Ok(self.finish_apply_groups(ac, ca));
}

let name = s.name().to_string();

let mut ca: ListChunked = ac
.aggregated()
let mut ca: ListChunked = agg
.list()
.unwrap()
.par_iter()
Expand Down
21 changes: 21 additions & 0 deletions polars/polars-lazy/src/tests/queries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2009,3 +2009,24 @@ fn test_partitioned_gb_ternary() -> PolarsResult<()> {

Ok(())
}

/// Regression test (see #6036): an aggregation whose `filter` removes every
/// row of a group must still evaluate instead of failing.
///
/// The frame has a single group (`a == 2`) and the predicate `b == 0` matches
/// no row, so the chained `diff` / multiply / `diff` expression runs on empty
/// input.
#[test]
fn test_foo() -> PolarsResult<()> {
    let df = df![
        "a" => [2, 2],
        "b" => [1, 2]
    ]?;

    // Propagate a query error with `?` so the test fails on `Err` instead of
    // merely printing it.
    let out = df
        .lazy()
        .groupby([col("a")])
        .agg([
            (col("a").filter(col("b").eq(0)).diff(1, Default::default()) * lit(100))
                .diff(1, Default::default())
                .alias("foo"),
        ])
        .collect()?;

    // One row for the single group; two columns: the key "a" and the
    // aggregated "foo".
    assert_eq!(out.shape(), (1, 2));

    Ok(())
}
15 changes: 15 additions & 0 deletions py-polars/tests/unit/test_aggregations.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,18 @@ def test_hmean_with_str_column() -> None:
assert pl.DataFrame(
{"int": [1, 2, 3], "bool": [True, True, None], "str": ["a", "b", "c"]}
).mean(axis=1).to_list() == [1.0, 1.5, 3.0]


def test_list_aggregation_that_filters_all_data_6017() -> None:
    # Single-row frame whose only row is rejected by the filter below
    # (col3 is 1, the predicate asks for 0).
    frame = pl.DataFrame(
        {"col_to_groupby": [2], "flt": [1672740910.967138], "col3": [1]}
    )

    # diff -> scale -> diff, applied to the (empty) filtered selection.
    filtered = pl.col("flt").filter(pl.col("col3") == 0)
    calc = (filtered.diff() * 1000).diff().alias("calc")

    out = frame.groupby("col_to_groupby").agg(calc)

    # The group is kept, the dtype is preserved, and the list is empty.
    expected_schema = {"col_to_groupby": pl.Int64, "calc": pl.List(pl.Float64)}
    expected_data = {"col_to_groupby": [2], "calc": [[]]}
    assert out.schema == expected_schema
    assert out.to_dict(False) == expected_data

0 comments on commit fde4e9c

Please sign in to comment.