Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[gen4] More aggregation support #10332

Merged
merged 8 commits into from
May 19, 2022
53 changes: 53 additions & 0 deletions go/vt/vtgate/planbuilder/abstract/queryprojection.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ type (
OrderExprs []OrderBy
CanPushDownSorting bool
HasStar bool

// AddedColumn keeps a counter for expressions added to solve HAVING expressions the user is not selecting
AddedColumn int
}

// OrderBy contains the expression to use in ORDER BY and, if ordering is needed at the VTGate level, the weight_string function expression to be sent down for evaluation.
Expand Down Expand Up @@ -172,6 +175,52 @@ func CreateQPFromSelect(sel *sqlparser.Select) (*QueryProjection, error) {
return qp, nil
}

// AggrRewriter carries the state used by the visitor returned from Rewrite:
// the QueryProjection whose select expressions are searched and extended, and
// the first error hit while rewriting.
type AggrRewriter struct {
// qp is the projection whose SelectExprs are matched against and appended to.
qp *QueryProjection
// Err is set when rewriting fails; the visitor cannot return an error itself,
// so callers must inspect this field once the rewrite has finished.
Err error
}
frouioui marked this conversation as resolved.
Show resolved Hide resolved

// Rewrite returns a visitor, intended to be passed as the pre-order callback of
// sqlparser.Rewrite, that replaces every aggregation it meets with a column
// offset into the QueryProjection's SelectExprs.
//
// An aggregation that is already equal to one of the select expressions is
// replaced by that expression's offset. Any other aggregation is appended to
// SelectExprs as a new aggregating column: HasAggr is set, AddedColumn is
// incremented, and the node is replaced by the offset of the new column.
//
// The visitor signature has no error result, so the first failure is stored in
// ar.Err and every later invocation returns false without doing any work.
// Callers must check Err once the rewrite has finished.
func (ar *AggrRewriter) Rewrite() func(*sqlparser.Cursor) bool {
	return func(cursor *sqlparser.Cursor) bool {
		if ar.Err != nil {
			return false
		}

		sqlNode := cursor.Node()
		if !sqlparser.IsAggregation(sqlNode) {
			return true
		}

		// Checked assertion: an aggregation node that is not a *FuncExpr is left
		// untouched rather than panicking the planner.
		// NOTE(review): confirm which node types IsAggregation accepts.
		fExp, ok := sqlNode.(*sqlparser.FuncExpr)
		if !ok {
			return true
		}

		// Reuse an existing column when the same aggregation is already projected.
		for offset, expr := range ar.qp.SelectExprs {
			ae, err := expr.GetAliasedExpr()
			if err != nil {
				ar.Err = err
				return false
			}
			if sqlparser.EqualsExpr(ae.Expr, fExp) {
				cursor.Replace(sqlparser.Offset(offset))
				return false // no need to visit aggregation children
			}
		}

		// Not projected yet: add it as an extra column and point the predicate at it.
		col := SelectExpr{
			Aggr: true,
			Col:  &sqlparser.AliasedExpr{Expr: fExp},
		}
		ar.qp.HasAggr = true

		cursor.Replace(sqlparser.Offset(len(ar.qp.SelectExprs)))
		ar.qp.SelectExprs = append(ar.qp.SelectExprs, col)
		ar.qp.AddedColumn++

		// fExp is now held by the QueryProjection. Do not descend into it, so the
		// stored expression cannot be rewritten in place; this also matches the
		// already-projected branch above.
		return false
	}
}

// AggrRewriter returns a new AggrRewriter bound to this QueryProjection, with
// no error recorded.
func (qp *QueryProjection) AggrRewriter() *AggrRewriter {
	rewriter := new(AggrRewriter)
	rewriter.qp = qp
	return rewriter
}

func (qp *QueryProjection) addSelectExpressions(sel *sqlparser.Select) error {
for _, selExp := range sel.SelectExprs {
switch selExp := selExp.(type) {
Expand Down Expand Up @@ -549,6 +598,10 @@ func (qp *QueryProjection) AddGroupBy(by GroupBy) {
qp.groupByExprs = append(qp.groupByExprs, by)
}

// GetColumnCount returns the number of select expressions minus the columns
// counted in AddedColumn.
func (qp *QueryProjection) GetColumnCount() int {
	total := len(qp.SelectExprs)
	return total - qp.AddedColumn
}

func checkForInvalidGroupingExpressions(expr sqlparser.Expr) error {
return sqlparser.Walk(func(node sqlparser.SQLNode) (bool, error) {
if sqlparser.IsAggregation(node) {
Expand Down
22 changes: 14 additions & 8 deletions go/vt/vtgate/planbuilder/horizon_planning.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,18 +138,18 @@ func pushProjections(ctx *plancontext.PlanningContext, plan logicalPlan, selectE
}

func (hp *horizonPlanning) truncateColumnsIfNeeded(ctx *plancontext.PlanningContext, plan logicalPlan) (logicalPlan, error) {
if len(plan.OutputColumns()) == hp.sel.GetColumnCount() {
if len(plan.OutputColumns()) == hp.qp.GetColumnCount() {
return plan, nil
}
switch p := plan.(type) {
case *routeGen4:
p.eroute.SetTruncateColumnCount(hp.sel.GetColumnCount())
p.eroute.SetTruncateColumnCount(hp.qp.GetColumnCount())
case *joinGen4, *semiJoin, *hashJoin:
// since this is a join, we can safely add extra columns and not need to truncate them
case *orderedAggregate:
p.truncateColumnCount = hp.sel.GetColumnCount()
p.truncateColumnCount = hp.qp.GetColumnCount()
case *memorySort:
p.truncater.SetTruncateColumnCount(hp.sel.GetColumnCount())
p.truncater.SetTruncateColumnCount(hp.qp.GetColumnCount())
case *pulloutSubquery:
newUnderlyingPlan, err := hp.truncateColumnsIfNeeded(ctx, p.underlying)
if err != nil {
Expand All @@ -162,7 +162,8 @@ func (hp *horizonPlanning) truncateColumnsIfNeeded(ctx *plancontext.PlanningCont
eSimpleProj: &engine.SimpleProjection{},
}

err := pushProjections(ctx, plan, hp.qp.SelectExprs)
exprs := hp.qp.SelectExprs[0:hp.qp.GetColumnCount()]
err := pushProjections(ctx, plan, exprs)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -528,6 +529,14 @@ func (hp *horizonPlanning) planAggrUsingOA(
})
}

if hp.sel.Having != nil {
rewriter := hp.qp.AggrRewriter()
sqlparser.Rewrite(hp.sel.Having.Expr, rewriter.Rewrite(), nil)
if rewriter.Err != nil {
return nil, rewriter.Err
}
}

aggregationExprs, err := hp.qp.AggregationExpressions()
if err != nil {
return nil, err
Expand Down Expand Up @@ -1340,9 +1349,6 @@ func pushHaving(ctx *plancontext.PlanningContext, expr sqlparser.Expr, plan logi
case *simpleProjection:
return nil, vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "unsupported: filtering on results of cross-shard derived table")
case *orderedAggregate:
if sqlparser.ContainsAggregation(expr) {
expr = sqlparser.Rewrite(expr, node.rewriteAggrExpressions(), nil).(sqlparser.Expr)
}
return newFilter(ctx, plan, expr)
}
return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "[BUG] unreachable %T.filtering", plan)
Expand Down
17 changes: 0 additions & 17 deletions go/vt/vtgate/planbuilder/ordered_aggregate.go
Original file line number Diff line number Diff line change
Expand Up @@ -387,20 +387,3 @@ func (oa *orderedAggregate) OutputColumns() []sqlparser.SelectExpr {
func (oa *orderedAggregate) SetTruncateColumnCount(count int) {
oa.truncateColumnCount = count
}

// rewriteAggrExpressions is used when our predicate expression contains aggregation.
// In these cases, we need to rewrite it, so it uses the column output from the ordered aggregate.
//
// The returned visitor replaces each aggregation that equals one of
// oa.aggregates with the offset of that aggregate's column. Aggregations with
// no equal entry are left unchanged. The visitor always returns true, so
// traversal continues into every node.
func (oa *orderedAggregate) rewriteAggrExpressions() func(*sqlparser.Cursor) bool {
	return func(cursor *sqlparser.Cursor) bool {
		sqlNode := cursor.Node()
		if !sqlparser.IsAggregation(sqlNode) {
			return true
		}

		// Checked assertion: an aggregation node that is not a *FuncExpr is
		// skipped instead of panicking.
		// NOTE(review): confirm which node types IsAggregation accepts.
		fExp, ok := sqlNode.(*sqlparser.FuncExpr)
		if !ok {
			return true
		}

		for _, aggregate := range oa.aggregates {
			if sqlparser.EqualsExpr(aggregate.Expr, fExp) {
				cursor.Replace(sqlparser.Offset(aggregate.Col))
			}
		}
		return true
	}
}
241 changes: 241 additions & 0 deletions go/vt/vtgate/planbuilder/testdata/aggr_cases.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4171,3 +4171,244 @@ Gen4 error: aggregate functions take a single argument 'count(distinct user_id,
]
}
}

# having should be able to add new aggregation expressions in having
"select foo from user group by foo having count(*) = 3"
"unsupported: filtering on results of aggregates"
{
"QueryType": "SELECT",
"Original": "select foo from user group by foo having count(*) = 3",
"Instructions": {
"OperatorType": "SimpleProjection",
"Columns": [
0
],
"Inputs": [
{
"OperatorType": "Filter",
"Predicate": "[1] = 3",
frouioui marked this conversation as resolved.
Show resolved Hide resolved
"Inputs": [
{
"OperatorType": "Aggregate",
"Variant": "Ordered",
"Aggregates": "sum(1) AS count(*)",
"GroupBy": "(0|2)",
"Inputs": [
{
"OperatorType": "Route",
"Variant": "Scatter",
"Keyspace": {
"Name": "user",
"Sharded": true
},
"FieldQuery": "select foo, count(*), weight_string(foo) from `user` where 1 != 1 group by foo, weight_string(foo)",
"OrderBy": "(0|2) ASC",
"Query": "select foo, count(*), weight_string(foo) from `user` group by foo, weight_string(foo) order by foo asc",
"Table": "`user`"
}
]
}
]
}
]
}
}

"select u.id from user u join user_extra ue on ue.id = u.id group by u.id having count(u.name) = 3"
"unsupported: cross-shard query with aggregates"
{
"QueryType": "SELECT",
"Original": "select u.id from user u join user_extra ue on ue.id = u.id group by u.id having count(u.name) = 3",
"Instructions": {
"OperatorType": "SimpleProjection",
"Columns": [
0
],
"Inputs": [
{
"OperatorType": "Filter",
"Predicate": "[1] = 3",
"Inputs": [
{
"OperatorType": "Aggregate",
"Variant": "Ordered",
"Aggregates": "sum(1) AS count(u.`name`)",
"GroupBy": "(0|2)",
"Inputs": [
{
"OperatorType": "Projection",
"Expressions": [
"[COLUMN 0] as id",
"[COLUMN 2] * [COLUMN 3] as count(u.`name`)",
"[COLUMN 1]"
],
"Inputs": [
{
"OperatorType": "Sort",
"Variant": "Memory",
"OrderBy": "(0|1) ASC",
"Inputs": [
{
"OperatorType": "Join",
"Variant": "Join",
"JoinColumnIndexes": "R:1,R:2,L:1,R:0",
"JoinVars": {
"ue_id": 0
},
"TableName": "user_extra_`user`",
"Inputs": [
{
"OperatorType": "Route",
"Variant": "Scatter",
"Keyspace": {
"Name": "user",
"Sharded": true
},
"FieldQuery": "select ue.id, count(*), weight_string(ue.id) from user_extra as ue where 1 != 1 group by ue.id, weight_string(ue.id)",
"Query": "select ue.id, count(*), weight_string(ue.id) from user_extra as ue group by ue.id, weight_string(ue.id)",
"Table": "user_extra"
},
{
"OperatorType": "Route",
"Variant": "EqualUnique",
"Keyspace": {
"Name": "user",
"Sharded": true
},
"FieldQuery": "select count(u.`name`), u.id, weight_string(u.id) from `user` as u where 1 != 1 group by u.id, weight_string(u.id)",
"Query": "select count(u.`name`), u.id, weight_string(u.id) from `user` as u where u.id = :ue_id group by u.id, weight_string(u.id)",
"Table": "`user`",
"Values": [
":ue_id"
],
"Vindex": "user_index"
}
]
}
]
}
]
}
]
}
]
}
]
}
}

"select u.id from user u join user_extra ue on ue.user_id = u.id group by u.id having count(u.name) = 3"
{
"QueryType": "SELECT",
"Original": "select u.id from user u join user_extra ue on ue.user_id = u.id group by u.id having count(u.name) = 3",
"Instructions": {
"OperatorType": "Route",
"Variant": "Scatter",
"Keyspace": {
"Name": "user",
"Sharded": true
},
"FieldQuery": "select u.id from `user` as u join user_extra as ue on ue.user_id = u.id where 1 != 1 group by u.id",
"Query": "select u.id from `user` as u join user_extra as ue on ue.user_id = u.id group by u.id having count(u.`name`) = 3",
"Table": "`user`, user_extra"
}
}
{
"QueryType": "SELECT",
"Original": "select u.id from user u join user_extra ue on ue.user_id = u.id group by u.id having count(u.name) = 3",
"Instructions": {
"OperatorType": "Route",
"Variant": "Scatter",
"Keyspace": {
"Name": "user",
"Sharded": true
},
"FieldQuery": "select u.id from `user` as u, user_extra as ue where 1 != 1 group by u.id",
"Query": "select u.id from `user` as u, user_extra as ue where ue.user_id = u.id group by u.id having count(u.`name`) = 3",
"Table": "`user`, user_extra"
}
}

# only extract the aggregation once, even if used twice
"select u.id from user u join user_extra ue on ue.id = u.id group by u.id having count(*) < 3 and count(*) > 5"
"unsupported: cross-shard query with aggregates"
{
"QueryType": "SELECT",
"Original": "select u.id from user u join user_extra ue on ue.id = u.id group by u.id having count(*) \u003c 3 and count(*) \u003e 5",
"Instructions": {
"OperatorType": "SimpleProjection",
"Columns": [
0
],
"Inputs": [
{
"OperatorType": "Filter",
"Predicate": "[1] \u003c 3 and [1] \u003e 5",
"Inputs": [
{
"OperatorType": "Aggregate",
"Variant": "Ordered",
"Aggregates": "sum(1) AS count(*)",
"GroupBy": "(0|2)",
"Inputs": [
{
"OperatorType": "Projection",
"Expressions": [
"[COLUMN 0] as id",
"[COLUMN 2] * [COLUMN 3] as count(*)",
"[COLUMN 1]"
],
"Inputs": [
{
"OperatorType": "Sort",
"Variant": "Memory",
"OrderBy": "(0|1) ASC",
"Inputs": [
{
"OperatorType": "Join",
"Variant": "Join",
"JoinColumnIndexes": "R:1,R:2,L:1,R:0",
"JoinVars": {
"ue_id": 0
},
"TableName": "user_extra_`user`",
"Inputs": [
{
"OperatorType": "Route",
"Variant": "Scatter",
"Keyspace": {
"Name": "user",
"Sharded": true
},
"FieldQuery": "select ue.id, count(*), weight_string(ue.id) from user_extra as ue where 1 != 1 group by ue.id, weight_string(ue.id)",
"Query": "select ue.id, count(*), weight_string(ue.id) from user_extra as ue group by ue.id, weight_string(ue.id)",
"Table": "user_extra"
},
{
"OperatorType": "Route",
"Variant": "EqualUnique",
"Keyspace": {
"Name": "user",
"Sharded": true
},
"FieldQuery": "select count(*), u.id, weight_string(u.id) from `user` as u where 1 != 1 group by u.id, weight_string(u.id)",
"Query": "select count(*), u.id, weight_string(u.id) from `user` as u where u.id = :ue_id group by u.id, weight_string(u.id)",
"Table": "`user`",
"Values": [
":ue_id"
],
"Vindex": "user_index"
}
]
}
]
}
]
}
]
}
]
}
]
}
}
Loading