Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor Expression and Statement Simplifier #13636

Merged
merged 17 commits into from
Aug 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
9 changes: 9 additions & 0 deletions go/vt/sqlparser/ast_funcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -718,6 +718,15 @@ func NewComparisonExpr(operator ComparisonExprOperator, left, right, escape Expr
}
}

// NewCaseExpr makes a new CaseExpr
func NewCaseExpr(expr Expr, whens []*When, elseExpr Expr) *CaseExpr {
return &CaseExpr{
Expr: expr,
Whens: whens,
Else: elseExpr,
}
}

// NewLimit makes a new Limit
func NewLimit(offset, rowCount int) *Limit {
return &Limit{
Expand Down
4 changes: 3 additions & 1 deletion go/vt/vtgate/planbuilder/simplifier_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ import (
// TestSimplifyBuggyQuery should be used to whenever we get a planner bug reported
// It will try to minimize the query to make it easier to understand and work with the bug.
func TestSimplifyBuggyQuery(t *testing.T) {
query := "(select id from unsharded union select id from unsharded_auto) union (select id from user union select name from unsharded)"
query := "select distinct count(distinct a), count(distinct 4) from user left join unsharded on 0 limit 5"
// select 0 from unsharded union select 0 from `user` union select 0 from unsharded
// select 0 from unsharded union (select 0 from `user` union select 0 from unsharded)
vschema := &vschemaWrapper{
v: loadSchema(t, "vschemas/schema.json", true),
version: Gen4,
Expand Down
147 changes: 68 additions & 79 deletions go/vt/vtgate/simplifier/expression_simplifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,104 +21,59 @@ import (
"strconv"

"vitess.io/vitess/go/vt/log"

"vitess.io/vitess/go/vt/sqlparser"
)

// CheckF is used to see if the given expression exhibits the sought after issue
type CheckF = func(sqlparser.Expr) bool

func SimplifyExpr(in sqlparser.Expr, test CheckF) (smallestKnown sqlparser.Expr) {
var maxDepth, level int
resetTo := func(e sqlparser.Expr) {
smallestKnown = e
maxDepth = depth(e)
level = 0
func SimplifyExpr(in sqlparser.Expr, test CheckF) sqlparser.Expr {
// since we can't rewrite the top level, wrap the expr in an Exprs object
smallestKnown := sqlparser.Exprs{sqlparser.CloneExpr(in)}

alwaysVisit := func(node, parent sqlparser.SQLNode) bool {
return true
}
resetTo(in)
for level <= maxDepth {
current := sqlparser.CloneExpr(smallestKnown)
nodes, replaceF := getNodesAtLevel(current, level)
replace := func(e sqlparser.Expr, idx int) {
// if we are at the first level, we are replacing the root,
// not rewriting something deep in the tree
if level == 0 {
current = e

up := func(cursor *sqlparser.Cursor) bool {
node := sqlparser.CloneSQLNode(cursor.Node())
s := &shrinker{orig: node}
expr := s.Next()
for expr != nil {
cursor.Replace(expr)

valid := test(smallestKnown[0])
log.Errorf("test: %t: simplified %s to %s, full expr: %s", valid, sqlparser.String(node), sqlparser.String(expr), sqlparser.String(smallestKnown))
if valid {
break // we will still continue trying to simplify other expressions at this level
} else {
// replace `node` in current with the simplified expression
replaceF[idx](e)
// undo the change
cursor.Replace(node)
}
expr = s.Next()
}
simplified := false
for idx, node := range nodes {
// simplify each element and create a new expression with the node replaced by the simplification
// this means that we not only need the node, but also a way to replace the node
s := &shrinker{orig: node}
expr := s.Next()
for expr != nil {
replace(expr, idx)

valid := test(current)
log.Errorf("test: %t - %s", valid, sqlparser.String(current))
if valid {
simplified = true
break // we will still continue trying to simplify other expressions at this level
} else {
// undo the change
replace(node, idx)
}
expr = s.Next()
}
}
if simplified {
resetTo(current)
} else {
level++
}
}
return smallestKnown
}

func getNodesAtLevel(e sqlparser.Expr, level int) (result []sqlparser.Expr, replaceF []func(node sqlparser.SQLNode)) {
lvl := 0
pre := func(cursor *sqlparser.Cursor) bool {
if expr, isExpr := cursor.Node().(sqlparser.Expr); level == lvl && isExpr {
result = append(result, expr)
replaceF = append(replaceF, cursor.ReplacerF())
}
lvl++
return true
}
post := func(cursor *sqlparser.Cursor) bool {
lvl--
return true
}
sqlparser.Rewrite(e, pre, post)
return
}

func depth(e sqlparser.Expr) (depth int) {
lvl := 0
pre := func(cursor *sqlparser.Cursor) bool {
lvl++
if lvl > depth {
depth = lvl
// loop until rewriting introduces no more changes
for {
prevSmallest := sqlparser.CloneExprs(smallestKnown)
sqlparser.SafeRewrite(smallestKnown, alwaysVisit, up)
if sqlparser.Equals.Exprs(prevSmallest, smallestKnown) {
break
}
return true
}
post := func(cursor *sqlparser.Cursor) bool {
lvl--
return true
}
sqlparser.Rewrite(e, pre, post)
return

return smallestKnown[0]
}

type shrinker struct {
orig sqlparser.Expr
queue []sqlparser.Expr
orig sqlparser.SQLNode
queue []sqlparser.SQLNode
}

func (s *shrinker) Next() sqlparser.Expr {
func (s *shrinker) Next() sqlparser.SQLNode {
for {
// first we check if there is already something in the queue.
// note that we are doing a nil check and not a length check here.
Expand All @@ -142,6 +97,10 @@ func (s *shrinker) Next() sqlparser.Expr {
func (s *shrinker) fillQueue() bool {
before := len(s.queue)
switch e := s.orig.(type) {
case *sqlparser.AndExpr:
s.queue = append(s.queue, e.Left, e.Right)
case *sqlparser.OrExpr:
s.queue = append(s.queue, e.Left, e.Right)
case *sqlparser.ComparisonExpr:
s.queue = append(s.queue, e.Left, e.Right)
case *sqlparser.BinaryExpr:
Expand Down Expand Up @@ -228,9 +187,39 @@ func (s *shrinker) fillQueue() bool {
for _, ae := range e.GetArgs() {
s.queue = append(s.queue, ae)
}

clone := sqlparser.CloneAggrFunc(e)
if da, ok := clone.(sqlparser.DistinctableAggr); ok {
if da.IsDistinct() {
da.SetDistinct(false)
s.queue = append(s.queue, clone)
}
}
case *sqlparser.ColName:
// we can try to replace the column with a literal value
s.queue = []sqlparser.Expr{sqlparser.NewIntLiteral("0")}
s.queue = append(s.queue, sqlparser.NewIntLiteral("0"))
case *sqlparser.CaseExpr:
s.queue = append(s.queue, e.Expr, e.Else)
for _, when := range e.Whens {
s.queue = append(s.queue, when.Cond, when.Val)
}

if len(e.Whens) > 1 {
for i := range e.Whens {
whensCopy := sqlparser.CloneSliceOfRefOfWhen(e.Whens)
// replace ith element with last element, then truncate last element
whensCopy[i] = whensCopy[len(whensCopy)-1]
whensCopy = whensCopy[:len(whensCopy)-1]
s.queue = append(s.queue, sqlparser.NewCaseExpr(e.Expr, whensCopy, e.Else))
}
}

if e.Else != nil {
s.queue = append(s.queue, sqlparser.NewCaseExpr(e.Expr, e.Whens, nil))
}
if e.Expr != nil {
s.queue = append(s.queue, sqlparser.NewCaseExpr(nil, e.Whens, e.Else))
}
default:
return false
}
Expand Down