Skip to content

Commit

Permalink
more partitioned groupby (#3355)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed May 10, 2022
1 parent bfbb7a7 commit 3700e69
Show file tree
Hide file tree
Showing 11 changed files with 191 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ fn run_partitions(
let agg_columns = phys_aggs
.iter()
.map(|expr| {
let agg_expr = expr.as_partitioned_aggregator()?;
let agg_expr = expr.as_partitioned_aggregator().unwrap();
let agg = agg_expr.evaluate_partitioned(&df, groups, state)?;
if agg.len() != groups.len() {
Err(PolarsError::ComputeError(
Expand Down Expand Up @@ -245,7 +245,7 @@ impl Executor for PartitionGroupByExec {
.zip(&df.get_columns()[self.keys.len()..])
.map(|(expr, partitioned_s)| {
let agg_expr = expr.as_partitioned_aggregator().unwrap();
agg_expr.finalize(partitioned_s, groups, state)
agg_expr.finalize(partitioned_s.clone(), groups, state)
})
.collect();

Expand Down
44 changes: 38 additions & 6 deletions polars/polars-lazy/src/physical_plan/expressions/aggregation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,8 @@ impl PhysicalExpr for AggregationExpr {
self.expr.to_field(input_schema)
}

fn as_partitioned_aggregator(&self) -> Result<&dyn PartitionedAggregation> {
Ok(self)
fn as_partitioned_aggregator(&self) -> Option<&dyn PartitionedAggregation> {
Some(self)
}
}

Expand All @@ -180,10 +180,11 @@ impl PartitionedAggregation for AggregationExpr {
groups: &GroupsProxy,
state: &ExecutionState,
) -> Result<Series> {
let expr = self.expr.as_partitioned_aggregator().unwrap();
let series = expr.evaluate_partitioned(df, groups, state)?;
match self.agg_type {
#[cfg(feature = "dtype-struct")]
GroupByMethod::Mean => {
let series = self.expr.evaluate(df, state)?;
let new_name = series.name().to_string();
let mut agg_s = series.agg_sum(groups);
agg_s.rename(&new_name);
Expand All @@ -203,19 +204,50 @@ impl PartitionedAggregation for AggregationExpr {
}
}
GroupByMethod::List => {
let series = self.expr.evaluate(df, state)?;
let new_name = series.name();
let mut agg = series.agg_list(groups);
agg.rename(new_name);
Ok(agg)
}
_ => Ok(self.evaluate_on_groups(df, groups, state)?.aggregated()),
GroupByMethod::First => {
let mut agg = series.agg_first(groups);
agg.rename(series.name());
Ok(agg)
}
GroupByMethod::Last => {
let mut agg = series.agg_last(groups);
agg.rename(series.name());
Ok(agg)
}
GroupByMethod::Max => {
let mut agg = series.agg_max(groups);
agg.rename(series.name());
Ok(agg)
}
GroupByMethod::Min => {
let mut agg = series.agg_min(groups);
agg.rename(series.name());
Ok(agg)
}
GroupByMethod::Sum => {
let mut agg = series.agg_sum(groups);
agg.rename(series.name());
Ok(agg)
}
GroupByMethod::Count => {
let mut ca = groups.group_count();
ca.rename(series.name());
Ok(ca.into_series())
}
_ => {
unimplemented!()
}
}
}

fn finalize(
&self,
partitioned: &Series,
partitioned: Series,
groups: &GroupsProxy,
_state: &ExecutionState,
) -> Result<Series> {
Expand Down
6 changes: 3 additions & 3 deletions polars/polars-lazy/src/physical_plan/expressions/alias.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ impl PhysicalExpr for AliasExpr {
))
}

fn as_partitioned_aggregator(&self) -> Result<&dyn PartitionedAggregation> {
Ok(self)
fn as_partitioned_aggregator(&self) -> Option<&dyn PartitionedAggregation> {
Some(self)
}
}

Expand All @@ -84,7 +84,7 @@ impl PartitionedAggregation for AliasExpr {

fn finalize(
&self,
partitioned: &Series,
partitioned: Series,
groups: &GroupsProxy,
state: &ExecutionState,
) -> Result<Series> {
Expand Down
28 changes: 28 additions & 0 deletions polars/polars-lazy/src/physical_plan/expressions/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,10 @@ impl PhysicalExpr for BinaryExpr {
self.expr.to_field(input_schema, Context::Default)
}

fn as_partitioned_aggregator(&self) -> Option<&dyn PartitionedAggregation> {
Some(self)
}

#[cfg(feature = "parquet")]
fn as_stats_evaluator(&self) -> Option<&dyn polars_io::predicates::StatsEvaluator> {
Some(self)
Expand Down Expand Up @@ -493,3 +497,27 @@ mod stats {
}
}
}

impl PartitionedAggregation for BinaryExpr {
fn evaluate_partitioned(
&self,
df: &DataFrame,
groups: &GroupsProxy,
state: &ExecutionState,
) -> Result<Series> {
let left = self.left.as_partitioned_aggregator().unwrap();
let right = self.right.as_partitioned_aggregator().unwrap();
let left = left.evaluate_partitioned(df, groups, state)?;
let right = right.evaluate_partitioned(df, groups, state)?;
apply_operator(&left, &right, self.op)
}

fn finalize(
&self,
partitioned: Series,
_groups: &GroupsProxy,
_state: &ExecutionState,
) -> Result<Series> {
Ok(partitioned)
}
}
25 changes: 25 additions & 0 deletions polars/polars-lazy/src/physical_plan/expressions/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,4 +89,29 @@ impl PhysicalExpr for CastExpr {
fld
})
}

fn as_partitioned_aggregator(&self) -> Option<&dyn PartitionedAggregation> {
Some(self)
}
}

impl PartitionedAggregation for CastExpr {
fn evaluate_partitioned(
&self,
df: &DataFrame,
groups: &GroupsProxy,
state: &ExecutionState,
) -> Result<Series> {
let e = self.input.as_partitioned_aggregator().unwrap();
e.evaluate_partitioned(df, groups, state)
}

fn finalize(
&self,
partitioned: Series,
_groups: &GroupsProxy,
_state: &ExecutionState,
) -> Result<Series> {
Ok(partitioned)
}
}
25 changes: 25 additions & 0 deletions polars/polars-lazy/src/physical_plan/expressions/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,11 @@ impl PhysicalExpr for ColumnExpr {
let s = self.evaluate(df, state)?;
Ok(AggregationContext::new(s, Cow::Borrowed(groups), false))
}

fn as_partitioned_aggregator(&self) -> Option<&dyn PartitionedAggregation> {
Some(self)
}

fn to_field(&self, input_schema: &Schema) -> Result<Field> {
let field = input_schema.get_field(&self.0).ok_or_else(|| {
PolarsError::NotFound(format!(
Expand All @@ -98,3 +103,23 @@ impl PhysicalExpr for ColumnExpr {
Ok(field)
}
}

impl PartitionedAggregation for ColumnExpr {
fn evaluate_partitioned(
&self,
df: &DataFrame,
_groups: &GroupsProxy,
state: &ExecutionState,
) -> Result<Series> {
self.evaluate(df, state)
}

fn finalize(
&self,
partitioned: Series,
_groups: &GroupsProxy,
_state: &ExecutionState,
) -> Result<Series> {
Ok(partitioned)
}
}
6 changes: 3 additions & 3 deletions polars/polars-lazy/src/physical_plan/expressions/count.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ impl PhysicalExpr for CountExpr {
Ok(Field::new("count", DataType::UInt32))
}

fn as_partitioned_aggregator(&self) -> Result<&dyn PartitionedAggregation> {
Ok(self)
fn as_partitioned_aggregator(&self) -> Option<&dyn PartitionedAggregation> {
Some(self)
}
}

Expand All @@ -76,7 +76,7 @@ impl PartitionedAggregation for CountExpr {
#[allow(clippy::ptr_arg)]
fn finalize(
&self,
partitioned: &Series,
partitioned: Series,
groups: &GroupsProxy,
_state: &ExecutionState,
) -> Result<Series> {
Expand Down
24 changes: 24 additions & 0 deletions polars/polars-lazy/src/physical_plan/expressions/literal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,10 @@ impl PhysicalExpr for LiteralExpr {
Ok(AggregationContext::from_literal(s, Cow::Borrowed(groups)))
}

fn as_partitioned_aggregator(&self) -> Option<&dyn PartitionedAggregation> {
Some(self)
}

fn to_field(&self, _input_schema: &Schema) -> Result<Field> {
use LiteralValue::*;
let name = "literal";
Expand Down Expand Up @@ -160,3 +164,23 @@ impl PhysicalExpr for LiteralExpr {
Ok(field)
}
}

impl PartitionedAggregation for LiteralExpr {
fn evaluate_partitioned(
&self,
df: &DataFrame,
_groups: &GroupsProxy,
state: &ExecutionState,
) -> Result<Series> {
self.evaluate(df, state)
}

fn finalize(
&self,
partitioned: Series,
_groups: &GroupsProxy,
_state: &ExecutionState,
) -> Result<Series> {
Ok(partitioned)
}
}
9 changes: 3 additions & 6 deletions polars/polars-lazy/src/physical_plan/expressions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -486,11 +486,8 @@ pub trait PhysicalExpr: Send + Sync {
fn to_field(&self, input_schema: &Schema) -> Result<Field>;

/// Convert to a partitioned aggregator.
fn as_partitioned_aggregator(&self) -> Result<&dyn PartitionedAggregation> {
let e = self.as_expression();
Err(PolarsError::InvalidOperation(
format!("{:?} is not an agg expression", e).into(),
))
fn as_partitioned_aggregator(&self) -> Option<&dyn PartitionedAggregation> {
None
}

/// Can take &dyn Statistics and determine of a file should be
Expand Down Expand Up @@ -540,7 +537,7 @@ pub trait PartitionedAggregation: Send + Sync + PhysicalExpr {
#[allow(clippy::ptr_arg)]
fn finalize(
&self,
partitioned: &Series,
partitioned: Series,
groups: &GroupsProxy,
state: &ExecutionState,
) -> Result<Series>;
Expand Down
6 changes: 5 additions & 1 deletion polars/polars-lazy/src/physical_plan/planner/lp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,11 @@ impl DefaultPlanner {
| AAggExpr::Count(_)
)
},
Column(_) | Alias(_, _) | Count => {
BinaryExpr {left, right, ..} => {
!has_aexpr(*left, expr_arena, |ae| matches!(ae, AExpr::Agg(_))) && !has_aexpr(*right, expr_arena, |ae| matches!(ae, AExpr::Agg(_)))
}

Column(_) | Alias(_, _) | Count | Literal(_) | Cast {..} => {
true
}
_ => {
Expand Down
35 changes: 35 additions & 0 deletions polars/polars-lazy/src/tests/queries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2030,3 +2030,38 @@ fn test_partitioned_gb_mean() -> Result<()> {

Ok(())
}

#[test]
fn test_partitioned_gb_binary() -> Result<()> {
// don't move these to integration tests
let df = df![
"col" => (0..20).map(|_| Some(0)).collect::<Int32Chunked>().into_series(),
]?;

let out = df
.clone()
.lazy()
.groupby([col("col")])
.agg([(col("col") + lit(10)).sum().alias("sum")])
.collect()?;

assert!(out.frame_equal(&df![
"col" => [0],
"sum" => [200],
]?));

let out = df
.lazy()
.groupby([col("col")])
.agg([(col("col").cast(DataType::Float32) + lit(10))
.sum()
.alias("sum")])
.collect()?;

assert!(out.frame_equal(&df![
"col" => [0],
"sum" => [200.0 as f32],
]?));

Ok(())
}

0 comments on commit 3700e69

Please sign in to comment.