Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(stream,batch,optimizer): support grouping agg call #11006

Merged
merged 2 commits into from
Jul 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions e2e_test/batch/aggregate/grouping_sets.slt.part
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@ create table items_sold (brand varchar, size varchar, sales int);
statement ok
insert into items_sold values ('Foo', 'L', 10),('Foo', 'M', 20),('Bar', 'M', 15),('Bar', 'L', '5');

query TTII rowsort
SELECT brand, size, sum(sales), count(distinct sales) FROM items_sold GROUP BY GROUPING SETS ((brand), (size), ());
query TTIIIII rowsort
SELECT brand, size, sum(sales), grouping(brand), grouping(size), grouping(brand,size), count(distinct sales) FROM items_sold GROUP BY GROUPING SETS ((brand), (size), ());
----
Bar NULL 20 2
Foo NULL 30 2
NULL L 15 2
NULL M 35 2
NULL NULL 50 4
Bar NULL 20 0 1 1 2
Foo NULL 30 0 1 1 2
NULL L 15 1 0 2 2
NULL M 35 1 0 2 2
NULL NULL 50 1 1 3 4

statement ok
drop table items_sold;
14 changes: 7 additions & 7 deletions e2e_test/streaming/grouping_sets.slt
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,16 @@ statement ok
insert into items_sold values ('Foo', 'L', 10),('Foo', 'M', 20),('Bar', 'M', 15),('Bar', 'L', '5');

statement ok
create materialized view v as SELECT brand, size, sum(sales), count(distinct sales) FROM items_sold GROUP BY GROUPING SETS ((brand), (size), ());
create materialized view v as SELECT brand, size, sum(sales), grouping(brand) g1, grouping(size) g2, grouping(brand,size) g3, count(distinct sales) FROM items_sold GROUP BY GROUPING SETS ((brand), (size), ());

query TTII rowsort
query TTIIIII rowsort
select * from v;
----
Bar NULL 20 2
Foo NULL 30 2
NULL L 15 2
NULL M 35 2
NULL NULL 50 4
Bar NULL 20 0 1 1 2
Foo NULL 30 0 1 1 2
NULL L 15 1 0 2 2
NULL M 35 1 0 2 2
NULL NULL 50 1 1 3 4

statement ok
drop materialized view v;
Expand Down
1 change: 1 addition & 0 deletions proto/expr.proto
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,7 @@ message AggCall {
PERCENTILE_DISC = 23;
MODE = 24;
LAST_VALUE = 25;
GROUPING = 26;
}
Type type = 1;
repeated InputRef args = 2;
Expand Down
4 changes: 4 additions & 0 deletions src/expr/src/agg/def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ pub enum AggKind {
PercentileCont,
PercentileDisc,
Mode,
Grouping,
}

impl AggKind {
Expand Down Expand Up @@ -266,6 +267,7 @@ impl AggKind {
PbType::PercentileCont => Ok(AggKind::PercentileCont),
PbType::PercentileDisc => Ok(AggKind::PercentileDisc),
PbType::Mode => Ok(AggKind::Mode),
PbType::Grouping => Ok(AggKind::Grouping),
PbType::Unspecified => bail!("Unrecognized agg."),
}
}
Expand Down Expand Up @@ -296,6 +298,7 @@ impl AggKind {
Self::VarSamp => PbType::VarSamp,
Self::PercentileCont => PbType::PercentileCont,
Self::PercentileDisc => PbType::PercentileDisc,
Self::Grouping => PbType::Grouping,
Self::Mode => PbType::Mode,
}
}
Expand Down Expand Up @@ -332,6 +335,7 @@ pub mod agg_kinds {
| AggKind::StddevSamp
| AggKind::VarPop
| AggKind::VarSamp
| AggKind::Grouping
};
}
pub use rewritten;
Expand Down
19 changes: 19 additions & 0 deletions src/frontend/planner_test/tests/testdata/input/grouping_sets.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,22 @@
expected_outputs:
- batch_plan
- stream_plan
- name: grouping agg calls
sql: |
create table items_sold (brand varchar, size varchar, sales int);
SELECT brand, size, sum(sales), grouping(brand) g1, grouping(size) g2, grouping(brand,size) g3, count(distinct sales) FROM items_sold GROUP BY GROUPING SETS ((brand), (size), ());
expected_outputs:
- batch_plan
- stream_plan
- name: too many arguments for grouping error
sql: |
create table items_sold (brand varchar, size varchar, sales int);
SELECT brand, size, sum(sales), grouping(brand, size, brand, size, brand, size, brand, size, brand, size, brand, size, brand, size, brand, size, brand, size, brand, size, brand, size, brand, size, brand, size, brand, size, brand, size, brand, size, size) FROM items_sold GROUP BY GROUPING SETS ((brand), (size), ());
expected_outputs:
- planner_error
- name: currently not support using grouping in query without grouping sets.
sql: |
create table items_sold (brand varchar, size varchar, sales int);
SELECT brand, size, sum(sales), grouping(size) FROM items_sold GROUP BY brand, size;
expected_outputs:
- planner_error
32 changes: 32 additions & 0 deletions src/frontend/planner_test/tests/testdata/output/grouping_sets.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,35 @@
└─StreamExpand { column_subsets: [[items_sold.size], [items_sold.brand], []] }
└─StreamProject { exprs: [items_sold.size, items_sold.brand, items_sold.sales, items_sold._row_id] }
└─StreamTableScan { table: items_sold, columns: [items_sold.brand, items_sold.size, items_sold.sales, items_sold._row_id], pk: [items_sold._row_id], dist: UpstreamHashShard(items_sold._row_id) }
- name: grouping agg calls
sql: |
create table items_sold (brand varchar, size varchar, sales int);
SELECT brand, size, sum(sales), grouping(brand) g1, grouping(size) g2, grouping(brand,size) g3, count(distinct sales) FROM items_sold GROUP BY GROUPING SETS ((brand), (size), ());
batch_plan: |-
BatchExchange { order: [], dist: Single }
└─BatchProject { exprs: [items_sold.brand, items_sold.size, sum(sum(items_sold.sales)), Case((0:Int64 = flag), 0:Int32, (1:Int64 = flag), 1:Int32, (2:Int64 = flag), 1:Int32) as $expr1, Case((0:Int64 = flag), 1:Int32, (1:Int64 = flag), 0:Int32, (2:Int64 = flag), 1:Int32) as $expr2, Case((0:Int64 = flag), 1:Int32, (1:Int64 = flag), 2:Int32, (2:Int64 = flag), 3:Int32) as $expr3, count(items_sold.sales)] }
└─BatchHashAgg { group_key: [items_sold.brand, items_sold.size, flag], aggs: [sum(sum(items_sold.sales)), count(items_sold.sales)] }
└─BatchExchange { order: [], dist: HashShard(items_sold.brand, items_sold.size, flag) }
└─BatchHashAgg { group_key: [items_sold.brand, items_sold.size, items_sold.sales, flag], aggs: [sum(items_sold.sales)] }
└─BatchExchange { order: [], dist: HashShard(items_sold.brand, items_sold.size, items_sold.sales, flag) }
└─BatchExpand { column_subsets: [[items_sold.brand], [items_sold.size], []] }
└─BatchScan { table: items_sold, columns: [items_sold.brand, items_sold.size, items_sold.sales], distribution: SomeShard }
stream_plan: |-
StreamMaterialize { columns: [brand, size, sum, g1, g2, g3, count, flag(hidden)], stream_key: [brand, size, flag], pk_columns: [brand, size, flag], pk_conflict: NoCheck }
└─StreamProject { exprs: [items_sold.brand, items_sold.size, sum(items_sold.sales), Case((0:Int64 = flag), 0:Int32, (1:Int64 = flag), 1:Int32, (2:Int64 = flag), 1:Int32) as $expr1, Case((0:Int64 = flag), 1:Int32, (1:Int64 = flag), 0:Int32, (2:Int64 = flag), 1:Int32) as $expr2, Case((0:Int64 = flag), 1:Int32, (1:Int64 = flag), 2:Int32, (2:Int64 = flag), 3:Int32) as $expr3, count(distinct items_sold.sales), flag] }
└─StreamHashAgg { group_key: [items_sold.brand, items_sold.size, flag], aggs: [sum(items_sold.sales), count(distinct items_sold.sales), count] }
└─StreamExchange { dist: HashShard(items_sold.brand, items_sold.size, flag) }
└─StreamExpand { column_subsets: [[items_sold.brand], [items_sold.size], []] }
└─StreamTableScan { table: items_sold, columns: [items_sold.brand, items_sold.size, items_sold.sales, items_sold._row_id], pk: [items_sold._row_id], dist: UpstreamHashShard(items_sold._row_id) }
- name: too many arguments for grouping error
sql: |
create table items_sold (brand varchar, size varchar, sales int);
SELECT brand, size, sum(sales), grouping(brand, size, brand, size, brand, size, brand, size, brand, size, brand, size, brand, size, brand, size, brand, size, brand, size, brand, size, brand, size, brand, size, brand, size, brand, size, brand, size, size) FROM items_sold GROUP BY GROUPING SETS ((brand), (size), ());
planner_error: 'Invalid input syntax: GROUPING must have fewer than 32 arguments'
- name: currently not support using grouping in query without grouping sets.
sql: |
create table items_sold (brand varchar, size varchar, sales int);
SELECT brand, size, sum(sales), grouping(size) FROM items_sold GROUP BY brand, size;
planner_error: |-
Not supported: GROUPING must be used in a query with grouping sets
HINT: try to use grouping sets instead
2 changes: 1 addition & 1 deletion src/frontend/src/expr/agg_call.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ impl AggCall {
_ => return Err(err()),
},
(AggKind::PercentileDisc | AggKind::Mode, [input]) => input.clone(),

(AggKind::Grouping, _) => Int32,
// other functions are handled by signature map
_ => {
let args = args.iter().map(|t| t.into()).collect::<Vec<_>>();
Expand Down
6 changes: 6 additions & 0 deletions src/frontend/src/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,12 @@ impl ExprImpl {
Literal::new(Some(v.to_scalar_value()), DataType::Int32).into()
}

/// A literal bigint value
#[inline(always)]
pub fn literal_bigint(v: i64) -> Self {
Literal::new(Some(v.to_scalar_value()), DataType::Int64).into()
}

/// A literal float64 value.
#[inline(always)]
pub fn literal_f64(v: f64) -> Self {
Expand Down
21 changes: 20 additions & 1 deletion src/frontend/src/optimizer/plan_node/logical_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,26 @@ impl LogicalAggBuilder {
let filter = filter.rewrite_expr(self);
self.is_in_filter_clause = false;

if matches!(agg_kind, AggKind::Grouping) {
if self.grouping_sets.is_empty() {
return Err(ErrorCode::NotSupported(
"GROUPING must be used in a query with grouping sets".into(),
"try to use grouping sets instead".into(),
));
}
if inputs.len() >= 32 {
return Err(ErrorCode::InvalidInputSyntax(
"GROUPING must have fewer than 32 arguments".into(),
));
}
if inputs.iter().any(|x| self.try_as_group_expr(x).is_none()) {
return Err(ErrorCode::InvalidInputSyntax(
"arguments to GROUPING must be grouping expressions of the associated query level"
.into(),
));
}
}

let inputs: Vec<_> = inputs
.iter()
.map(|expr| {
Expand Down Expand Up @@ -647,7 +667,6 @@ impl LogicalAggBuilder {
_ => unreachable!(),
}
}

_ => Ok(self
.push_agg_call(PlanAggCall {
agg_kind,
Expand Down
89 changes: 74 additions & 15 deletions src/frontend/src/optimizer/rule/grouping_sets_to_expand_rule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@

use fixedbitset::FixedBitSet;
use itertools::Itertools;
use risingwave_common::types::DataType;
use risingwave_common::util::column_index_mapping::ColIndexMapping;
use risingwave_expr::agg::AggKind;

use super::super::plan_node::*;
use super::{BoxedRule, Rule};
use crate::expr::{Expr, InputRef};
use crate::expr::{Expr, ExprImpl, ExprType, FunctionCall, InputRef};
use crate::optimizer::plan_node::generic::{Agg, GenericPlanNode, GenericPlanRef};
pub struct GroupingSetsToExpandRule {}

Expand Down Expand Up @@ -75,7 +77,7 @@ impl Rule for GroupingSetsToExpandRule {
let agg = Self::prune_column_for_agg(agg);
let (agg_calls, mut group_keys, grouping_sets, input) = agg.decompose();

let original_group_keys_num = group_keys.len();
let flag_col_idx = group_keys.len();
let input_schema_len = input.schema().len();

// TODO: support GROUPING expression.
Expand All @@ -85,35 +87,92 @@ impl Rule for GroupingSetsToExpandRule {
.iter()
.map(|set| set.indices().collect_vec())
.collect_vec();
let expand = LogicalExpand::create(input, column_subset);

let expand = LogicalExpand::create(input, column_subset.clone());
// Add the expand flag.
group_keys.extend(std::iter::once(expand.schema().len() - 1));

let mut input_col_change =
ColIndexMapping::with_shift_offset(input_schema_len, input_schema_len as isize);

// Shift agg_call to the original input columns
let new_agg_calls = agg_calls
.iter()
.cloned()
.map(|mut agg_call| {
// Grouping agg calls need to be transformed into a project expression, and other agg calls
// need to shift their `input_ref`.
let mut project_exprs = vec![];
let mut new_agg_calls = vec![];
for mut agg_call in agg_calls {
// Deal with grouping agg call for grouping sets.
if agg_call.agg_kind == AggKind::Grouping {
let mut grouping_values = vec![];
let args = agg_call
.inputs
.iter()
.map(|input_ref| input_ref.index)
.collect_vec();
for subset in &column_subset {
let mut value = 0;
for arg in &args {
value <<= 1;
if !subset.contains(arg) {
value += 1;
}
}
grouping_values.push(value);
}

let mut case_inputs = vec![];
for (i, grouping_value) in grouping_values.into_iter().enumerate() {
let condition = ExprImpl::FunctionCall(
FunctionCall::new_unchecked(
ExprType::Equal,
vec![
ExprImpl::literal_bigint(i as i64),
ExprImpl::InputRef(
InputRef::new(flag_col_idx, DataType::Int64).into(),
),
],
DataType::Boolean,
)
.into(),
);
let value = ExprImpl::literal_int(grouping_value);
case_inputs.push(condition);
case_inputs.push(value);
}

let case_expr = ExprImpl::FunctionCall(
FunctionCall::new_unchecked(ExprType::Case, case_inputs, DataType::Int32)
.into(),
);
project_exprs.push(case_expr);
} else {
// Shift agg_call to the original input columns
agg_call.inputs.iter_mut().for_each(|i| {
*i = InputRef::new(input_col_change.map(i.index()), i.return_type())
});
agg_call.order_by.iter_mut().for_each(|o| {
o.column_index = input_col_change.map(o.column_index);
});
agg_call.filter = agg_call.filter.rewrite_expr(&mut input_col_change);
agg_call
})
.collect();
let agg_return_type = agg_call.return_type.clone();
new_agg_calls.push(agg_call);
project_exprs.push(ExprImpl::InputRef(
InputRef::new(group_keys.len() + new_agg_calls.len() - 1, agg_return_type)
.into(),
));
}
}

let new_agg = Agg::new(new_agg_calls, group_keys, expand);
let project_exprs = (0..flag_col_idx)
.map(|i| {
ExprImpl::InputRef(
InputRef::new(i, new_agg.schema().fields()[i].data_type.clone()).into(),
)
})
chenzl25 marked this conversation as resolved.
Show resolved Hide resolved
.chain(project_exprs)
.collect();

let mut output_fields = FixedBitSet::with_capacity(new_agg.schema().len());
output_fields.toggle(original_group_keys_num);
output_fields.toggle_range(..);
let project = LogicalProject::with_out_fields(new_agg.into(), &output_fields);
let project = LogicalProject::new(new_agg.into(), project_exprs);

Some(project.into())
}
Expand Down
Loading