Skip to content

Commit

Permalink
allow literal as aggregation (#3722)
Browse files Browse the repository at this point in the history
* allow literal as aggregation

* auto explode product
  • Loading branch information
ritchie46 committed Jun 17, 2022
1 parent 174a213 commit 99f5f0f
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 4 deletions.
11 changes: 9 additions & 2 deletions polars/polars-lazy/src/dsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -955,7 +955,14 @@ impl Expr {
/// Get the product aggregation of an expression
#[cfg_attr(docsrs, doc(cfg(feature = "product")))]
pub fn product(self) -> Self {
self.apply(
let options = FunctionOptions {
collect_groups: ApplyOptions::ApplyGroups,
input_wildcard_expansion: false,
auto_explode: true,
fmt_str: "product",
};

self.function_with_options(
move |s: Series| Ok(s.product()),
GetOutput::map_dtype(|dt| {
use DataType::*;
Expand All @@ -965,8 +972,8 @@ impl Expr {
_ => Int64,
}
}),
options,
)
.with_fmt("product")
}

/// Fill missing value with next non-null.
Expand Down
2 changes: 1 addition & 1 deletion polars/polars-lazy/src/physical_plan/executors/groupby.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ pub(super) fn groupby_helper(
let get_agg = || aggs
.par_iter()
.map(|expr| {
let agg = expr.evaluate_on_groups(&df, groups, state)?.aggregated();
let agg = expr.evaluate_on_groups(&df, groups, state)?.finalize();
if agg.len() != groups.len() {
return Err(PolarsError::ComputeError(
format!("returned aggregation is a different length: {} than the group lengths: {}",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ impl Executor for GroupByDynamicExec {
self.aggs
.par_iter()
.map(|expr| {
let agg = expr.evaluate_on_groups(&df, groups, state)?.aggregated();
let agg = expr.evaluate_on_groups(&df, groups, state)?.finalize();
if agg.len() != groups.len() {
return Err(PolarsError::ComputeError(
format!("returned aggregation is a different length: {} than the group lengths: {}",
Expand Down
15 changes: 15 additions & 0 deletions polars/polars-lazy/src/physical_plan/expressions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,21 @@ impl<'a> AggregationContext<'a> {
}
}

/// Produce the series that represents the final result of this aggregation.
///
/// When the state is a literal, the literal is expanded (by repeating its first
/// value) to one row per group. Every other state is delegated to
/// [`Self::aggregated`].
pub(crate) fn finalize(&mut self) -> Series {
    // Take an owned copy of the literal so the borrow of `self.state` has ended
    // before `self.groups()` is called. `self.groups()` is only invoked on the
    // literal path, because it may instantiate new groups and can be expensive.
    let literal = match &self.state {
        AggState::Literal(s) => s.clone(),
        _ => return self.aggregated(),
    };

    self.groups();
    let n_groups = self.groups.len();
    literal.expand_at_index(0, n_groups)
}

/// Different from aggregated, in arity operations we expect literals to expand to the size of the
/// group
/// eg:
Expand Down
17 changes: 17 additions & 0 deletions polars/tests/it/lazy/groupby.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,3 +109,20 @@ fn test_filter_diff_arithmetic() -> Result<()> {

Ok(())
}

/// A bare literal used as a groupby aggregation must collect successfully and
/// produce a column that keeps the literal's dtype.
#[test]
fn test_groupby_lit_agg() -> Result<()> {
    let frame = df![
        "group" => [1, 2, 1, 1, 2],
    ]?;

    // The literal is the only aggregation expression; it is aliased so the
    // output column can be looked up by name.
    let literal_agg = lit("foo").alias("foo");
    let aggregated = frame
        .lazy()
        .groupby([col("group")])
        .agg([literal_agg])
        .collect()?;

    let foo = aggregated.column("foo")?;
    assert_eq!(foo.dtype(), &DataType::Utf8);

    Ok(())
}

0 comments on commit 99f5f0f

Please sign in to comment.