Skip to content

Commit

Permalink
don't rename aggregations (#2603)
Browse files Browse the repository at this point in the history
* don't rename aggregations
  • Loading branch information
ritchie46 committed Feb 11, 2022
1 parent b8e991f commit 1e8480f
Show file tree
Hide file tree
Showing 11 changed files with 191 additions and 219 deletions.
137 changes: 30 additions & 107 deletions polars/polars-lazy/src/logical_plan/aexpr.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
use crate::logical_plan::Context;
use crate::prelude::*;
use crate::utils::rename_field;
use polars_arrow::prelude::QuantileInterpolOptions;
use polars_core::frame::groupby::{fmt_groupby_column, GroupByMethod};
use polars_core::prelude::*;
use polars_core::utils::{get_supertype, get_time_units};
use polars_utils::arena::{Arena, Node};
Expand Down Expand Up @@ -214,115 +212,58 @@ impl AExpr {
Filter { input, .. } => arena.get(*input).to_field(schema, ctxt, arena),
Agg(agg) => {
use AAggExpr::*;
let field = match agg {
Min(expr) => field_by_context(
arena.get(*expr).to_field(schema, ctxt, arena)?,
ctxt,
GroupByMethod::Min,
),
Max(expr) => field_by_context(
arena.get(*expr).to_field(schema, ctxt, arena)?,
ctxt,
GroupByMethod::Max,
),
match agg {
Max(expr) | Sum(expr) | Min(expr) | First(expr) | Last(expr) => {
arena.get(*expr).to_field(schema, ctxt, arena)
}
Median(expr) => {
let mut field = field_by_context(
arena.get(*expr).to_field(schema, ctxt, arena)?,
ctxt,
GroupByMethod::Median,
);
let mut field = arena.get(*expr).to_field(schema, ctxt, arena)?;
if field.data_type() != &DataType::Utf8 {
field.coerce(DataType::Float64);
}
field
Ok(field)
}
Mean(expr) => {
let mut field = field_by_context(
arena.get(*expr).to_field(schema, ctxt, arena)?,
ctxt,
GroupByMethod::Mean,
);
let mut field = arena.get(*expr).to_field(schema, ctxt, arena)?;
field.coerce(DataType::Float64);
field
Ok(field)
}
List(expr) => {
let mut field = arena.get(*expr).to_field(schema, ctxt, arena)?;
field.coerce(DataType::List(field.data_type().clone().into()));
Ok(field)
}
First(expr) => field_by_context(
arena.get(*expr).to_field(schema, ctxt, arena)?,
ctxt,
GroupByMethod::First,
),
Last(expr) => field_by_context(
arena.get(*expr).to_field(schema, ctxt, arena)?,
ctxt,
GroupByMethod::Last,
),
List(expr) => field_by_context(
arena.get(*expr).to_field(schema, ctxt, arena)?,
ctxt,
GroupByMethod::List,
),
Std(expr) => {
let field = arena.get(*expr).to_field(schema, ctxt, arena)?;
let field = Field::new(field.name(), DataType::Float64);
let mut field = field_by_context(field, ctxt, GroupByMethod::Std);
let mut field = arena.get(*expr).to_field(schema, ctxt, arena)?;
field.coerce(DataType::Float64);
field
Ok(field)
}
Var(expr) => {
let field = arena.get(*expr).to_field(schema, ctxt, arena)?;
let field = Field::new(field.name(), DataType::Float64);
let mut field = field_by_context(field, ctxt, GroupByMethod::Var);
let mut field = arena.get(*expr).to_field(schema, ctxt, arena)?;
field.coerce(DataType::Float64);
field
Ok(field)
}
NUnique(expr) => {
let field = arena.get(*expr).to_field(schema, ctxt, arena)?;
let field = Field::new(field.name(), DataType::UInt32);
match ctxt {
Context::Default => field,
Context::Aggregation => {
let new_name =
fmt_groupby_column(field.name(), GroupByMethod::NUnique);
rename_field(&field, &new_name)
}
}
let mut field = arena.get(*expr).to_field(schema, ctxt, arena)?;
field.coerce(DataType::UInt32);
Ok(field)
}
Sum(expr) => field_by_context(
arena.get(*expr).to_field(schema, ctxt, arena)?,
ctxt,
GroupByMethod::Sum,
),
Count(expr) => {
let field = arena.get(*expr).to_field(schema, ctxt, arena)?;
let field = Field::new(field.name(), DataType::UInt32);
match ctxt {
Context::Default => field,
Context::Aggregation => {
let new_name =
fmt_groupby_column(field.name(), GroupByMethod::Count);
rename_field(&field, &new_name)
}
}
let mut field = arena.get(*expr).to_field(schema, ctxt, arena)?;
field.coerce(DataType::UInt32);
Ok(field)
}
AggGroups(expr) => {
let field = arena.get(*expr).to_field(schema, ctxt, arena)?;
let new_name = fmt_groupby_column(field.name(), GroupByMethod::Groups);
Field::new(&new_name, DataType::List(DataType::UInt32.into()))
let mut field = arena.get(*expr).to_field(schema, ctxt, arena)?;
field.coerce(DataType::List(DataType::UInt32.into()));
Ok(field)
}
Quantile {
expr,
quantile,
interpol,
} => {
let mut field = field_by_context(
arena.get(*expr).to_field(schema, ctxt, arena)?,
ctxt,
GroupByMethod::Quantile(*quantile, *interpol),
);
Quantile { expr, .. } => {
let mut field = arena.get(*expr).to_field(schema, ctxt, arena)?;
field.coerce(DataType::Float64);
field
Ok(field)
}
};
Ok(field)
}
}
Cast {
expr, data_type, ..
Expand Down Expand Up @@ -358,21 +299,3 @@ impl AExpr {
}
}
}

pub(crate) fn field_by_context(
mut field: Field,
ctxt: Context,
groupby_method: GroupByMethod,
) -> Field {
if &DataType::Boolean == field.data_type() {
field = Field::new(field.name(), DataType::UInt32)
}

match ctxt {
Context::Default => field,
Context::Aggregation => {
let new_name = fmt_groupby_column(field.name(), groupby_method);
rename_field(&field, &new_name)
}
}
}
2 changes: 1 addition & 1 deletion polars/polars-lazy/src/logical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ mod test {
.agg([col("sepal.width").min()])
.logical_plan;
println!("{:#?}", lp.schema().fields());
assert!(lp.schema().field_with_name("sepal.width_min").is_ok());
assert!(lp.schema().field_with_name("sepal.width").is_ok());
}

#[test]
Expand Down
57 changes: 24 additions & 33 deletions polars/polars-lazy/src/physical_plan/expressions/aggregation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::physical_plan::PhysicalAggregation;
use crate::prelude::*;
use polars_arrow::export::arrow::{array::*, compute::concatenate::concatenate};
use polars_arrow::prelude::QuantileInterpolOptions;
use polars_core::frame::groupby::{fmt_groupby_column, GroupByMethod, GroupsProxy};
use polars_core::frame::groupby::{GroupByMethod, GroupsProxy};
use polars_core::{prelude::*, POOL};
use std::borrow::Cow;
use std::sync::Arc;
Expand Down Expand Up @@ -41,9 +41,7 @@ impl PhysicalExpr for AggregationExpr {
}

fn to_field(&self, input_schema: &Schema) -> Result<Field> {
let field = self.expr.to_field(input_schema)?;
let new_name = fmt_groupby_column(field.name(), self.agg_type);
Ok(Field::new(&new_name, field.data_type().clone()))
self.expr.to_field(input_schema)
}

fn as_agg_expr(&self) -> Result<&dyn PhysicalAggregation> {
Expand All @@ -66,68 +64,69 @@ impl PhysicalAggregation for AggregationExpr {
state: &ExecutionState,
) -> Result<Option<Series>> {
let mut ac = self.expr.evaluate_on_groups(df, groups, state)?;
let new_name = fmt_groupby_column(ac.series().name(), self.agg_type);
// don't change names by aggregations as is done in polars-core
let keep_name = ac.series().name().to_string();

match self.agg_type {
GroupByMethod::Min => {
let agg_s = ac.flat_naive().into_owned().agg_min(ac.groups());
Ok(rename_option_series(agg_s, &new_name))
Ok(rename_option_series(agg_s, &keep_name))
}
GroupByMethod::Max => {
let agg_s = ac.flat_naive().into_owned().agg_max(ac.groups());
Ok(rename_option_series(agg_s, &new_name))
Ok(rename_option_series(agg_s, &keep_name))
}
GroupByMethod::Median => {
let agg_s = ac.flat_naive().into_owned().agg_median(ac.groups());
Ok(rename_option_series(agg_s, &new_name))
Ok(rename_option_series(agg_s, &keep_name))
}
GroupByMethod::Mean => {
let agg_s = ac.flat_naive().into_owned().agg_mean(ac.groups());
Ok(rename_option_series(agg_s, &new_name))
Ok(rename_option_series(agg_s, &keep_name))
}
GroupByMethod::Sum => {
let agg_s = ac.flat_naive().into_owned().agg_sum(ac.groups());
Ok(rename_option_series(agg_s, &new_name))
Ok(rename_option_series(agg_s, &keep_name))
}
GroupByMethod::Count => {
let mut ca = ac.groups.group_count();
ca.rename(&new_name);
ca.rename(&keep_name);
Ok(Some(ca.into_series()))
}
GroupByMethod::First => {
let mut agg_s = ac.flat_naive().into_owned().agg_first(ac.groups());
agg_s.rename(&new_name);
agg_s.rename(&keep_name);
Ok(Some(agg_s))
}
GroupByMethod::Last => {
let mut agg_s = ac.flat_naive().into_owned().agg_last(ac.groups());
agg_s.rename(&new_name);
agg_s.rename(&keep_name);
Ok(Some(agg_s))
}
GroupByMethod::NUnique => {
let opt_agg = ac.flat_naive().into_owned().agg_n_unique(ac.groups());
let opt_agg = opt_agg.map(|mut agg| {
agg.rename(&new_name);
agg.rename(&keep_name);
agg.into_series()
});
Ok(opt_agg)
}
GroupByMethod::List => {
let agg = ac.aggregated();
Ok(rename_option_series(Some(agg), &new_name))
Ok(rename_option_series(Some(agg), &keep_name))
}
GroupByMethod::Groups => {
let mut column: ListChunked = ac.groups().as_list_chunked();
column.rename(&new_name);
column.rename(&keep_name);
Ok(Some(column.into_series()))
}
GroupByMethod::Std => {
let agg_s = ac.flat_naive().into_owned().agg_std(ac.groups());
Ok(rename_option_series(agg_s, &new_name))
Ok(rename_option_series(agg_s, &keep_name))
}
GroupByMethod::Var => {
let agg_s = ac.flat_naive().into_owned().agg_var(ac.groups());
Ok(rename_option_series(agg_s, &new_name))
Ok(rename_option_series(agg_s, &keep_name))
}
GroupByMethod::Quantile(_, _) => {
// implemented explicitly in AggQuantile struct
Expand All @@ -145,7 +144,7 @@ impl PhysicalAggregation for AggregationExpr {
match self.agg_type {
GroupByMethod::Mean => {
let series = self.expr.evaluate(df, state)?;
let mut new_name = fmt_groupby_column(series.name(), self.agg_type);
let mut new_name = series.name().to_string();
let agg_s = series.agg_sum(groups);

// If the aggregation is successful,
Expand All @@ -165,10 +164,10 @@ impl PhysicalAggregation for AggregationExpr {
}
GroupByMethod::List => {
let series = self.expr.evaluate(df, state)?;
let new_name = fmt_groupby_column(series.name(), self.agg_type);
let new_name = series.name();
let opt_agg = series.agg_list(groups);
Ok(opt_agg.map(|mut s| {
s.rename(&new_name);
s.rename(new_name);
vec![s]
}))
}
Expand All @@ -187,7 +186,7 @@ impl PhysicalAggregation for AggregationExpr {
GroupByMethod::Mean => {
let series = self.expr.evaluate(final_df, state)?;
let count_name = format!("{}__POLARS_MEAN_COUNT", series.name());
let new_name = fmt_groupby_column(series.name(), self.agg_type);
let new_name = series.name().to_string();
let count = final_df.column(&count_name).unwrap();

let (agg_count, agg_s) =
Expand All @@ -200,7 +199,7 @@ impl PhysicalAggregation for AggregationExpr {
// we now must collect them into a single group
let series = self.expr.evaluate(final_df, state)?;
let ca = series.list().unwrap();
let new_name = fmt_groupby_column(ca.name(), self.agg_type);
let new_name = series.name().to_string();

let mut values = Vec::with_capacity(groups.len());
let mut can_fast_explode = true;
Expand Down Expand Up @@ -253,10 +252,7 @@ impl PhysicalAggregation for AggQuantileExpr {
state: &ExecutionState,
) -> Result<Option<Series>> {
let series = self.expr.evaluate(df, state)?;
let new_name = fmt_groupby_column(
series.name(),
GroupByMethod::Quantile(self.quantile, self.interpol),
);
let new_name = series.name().to_string();
let opt_agg = series.agg_quantile(groups, self.quantile, self.interpol);

let opt_agg = opt_agg.map(|mut agg| {
Expand Down Expand Up @@ -319,12 +315,7 @@ impl PhysicalExpr for AggQuantileExpr {
}

fn to_field(&self, input_schema: &Schema) -> Result<Field> {
let field = self.expr.to_field(input_schema)?;
let new_name = fmt_groupby_column(
field.name(),
GroupByMethod::Quantile(self.quantile, self.interpol),
);
Ok(Field::new(&new_name, field.data_type().clone()))
self.expr.to_field(input_schema)
}

fn as_agg_expr(&self) -> Result<&dyn PhysicalAggregation> {
Expand Down

0 comments on commit 1e8480f

Please sign in to comment.