Skip to content

Commit

Permalink
bindinfo: fix bugs when using bindings (#14263)
Browse files Browse the repository at this point in the history
  • Loading branch information
alivxxx committed Feb 4, 2020
1 parent b33e1a2 commit 1cc43bb
Show file tree
Hide file tree
Showing 8 changed files with 92 additions and 168 deletions.
185 changes: 46 additions & 139 deletions bindinfo/bind.go
Expand Up @@ -15,155 +15,62 @@ package bindinfo

import "github.com/pingcap/parser/ast"

// BindHint will add hints for originStmt according to hintedStmt' hints.
func BindHint(originStmt, hintedStmt ast.StmtNode) ast.StmtNode {
switch x := originStmt.(type) {
case *ast.SelectStmt:
return selectBind(x, hintedStmt.(*ast.SelectStmt))
default:
return originStmt
}
}

func selectBind(originalNode, hintedNode *ast.SelectStmt) *ast.SelectStmt {
if hintedNode.TableHints != nil {
originalNode.TableHints = hintedNode.TableHints
}
if originalNode.From != nil {
originalNode.From.TableRefs = resultSetNodeBind(originalNode.From.TableRefs, hintedNode.From.TableRefs).(*ast.Join)
}
if originalNode.Where != nil {
originalNode.Where = exprBind(originalNode.Where, hintedNode.Where).(ast.ExprNode)
}

if originalNode.Having != nil {
originalNode.Having.Expr = exprBind(originalNode.Having.Expr, hintedNode.Having.Expr)
}

if originalNode.OrderBy != nil {
originalNode.OrderBy = orderByBind(originalNode.OrderBy, hintedNode.OrderBy)
}

if originalNode.Fields != nil {
origFields := originalNode.Fields.Fields
hintFields := hintedNode.Fields.Fields
for idx := range origFields {
origFields[idx].Expr = exprBind(origFields[idx].Expr, hintFields[idx].Expr)
}
}
return originalNode
// HintsSet contains all hints of a query.
type HintsSet struct {
tableHints [][]*ast.TableOptimizerHint // Slice offset is the traversal order of `SelectStmt` in the ast.
indexHints [][]*ast.IndexHint // Slice offset is the traversal order of `TableName` in the ast.
}

func orderByBind(originalNode, hintedNode *ast.OrderByClause) *ast.OrderByClause {
for idx := 0; idx < len(originalNode.Items); idx++ {
originalNode.Items[idx].Expr = exprBind(originalNode.Items[idx].Expr, hintedNode.Items[idx].Expr)
}
return originalNode
type hintProcessor struct {
*HintsSet
// bindHint2Ast indicates the behavior of the processor, `true` for bind hint to ast, `false` for extract hint from ast.
bindHint2Ast bool
tableCounter int
indexCounter int
}

func exprBind(originalNode, hintedNode ast.ExprNode) ast.ExprNode {
switch v := originalNode.(type) {
case *ast.SubqueryExpr:
if v.Query != nil {
v.Query = resultSetNodeBind(v.Query, hintedNode.(*ast.SubqueryExpr).Query)
}
case *ast.ExistsSubqueryExpr:
if v.Sel != nil {
v.Sel.(*ast.SubqueryExpr).Query = resultSetNodeBind(v.Sel.(*ast.SubqueryExpr).Query, hintedNode.(*ast.ExistsSubqueryExpr).Sel.(*ast.SubqueryExpr).Query)
}
case *ast.PatternInExpr:
if v.Sel != nil {
v.Sel.(*ast.SubqueryExpr).Query = resultSetNodeBind(v.Sel.(*ast.SubqueryExpr).Query, hintedNode.(*ast.PatternInExpr).Sel.(*ast.SubqueryExpr).Query)
}
case *ast.BinaryOperationExpr:
if v.L != nil {
v.L = exprBind(v.L, hintedNode.(*ast.BinaryOperationExpr).L)
}
if v.R != nil {
v.R = exprBind(v.R, hintedNode.(*ast.BinaryOperationExpr).R)
}
case *ast.IsNullExpr:
if v.Expr != nil {
v.Expr = exprBind(v.Expr, hintedNode.(*ast.IsNullExpr).Expr)
}
case *ast.IsTruthExpr:
if v.Expr != nil {
v.Expr = exprBind(v.Expr, hintedNode.(*ast.IsTruthExpr).Expr)
}
case *ast.PatternLikeExpr:
if v.Pattern != nil {
v.Pattern = exprBind(v.Pattern, hintedNode.(*ast.PatternLikeExpr).Pattern)
}
case *ast.CompareSubqueryExpr:
if v.L != nil {
v.L = exprBind(v.L, hintedNode.(*ast.CompareSubqueryExpr).L)
}
if v.R != nil {
v.R = exprBind(v.R, hintedNode.(*ast.CompareSubqueryExpr).R)
}
case *ast.BetweenExpr:
if v.Left != nil {
v.Left = exprBind(v.Left, hintedNode.(*ast.BetweenExpr).Left)
}
if v.Right != nil {
v.Right = exprBind(v.Right, hintedNode.(*ast.BetweenExpr).Right)
}
case *ast.UnaryOperationExpr:
if v.V != nil {
v.V = exprBind(v.V, hintedNode.(*ast.UnaryOperationExpr).V)
}
case *ast.CaseExpr:
if v.Value != nil {
v.Value = exprBind(v.Value, hintedNode.(*ast.CaseExpr).Value)
}
if v.ElseClause != nil {
v.ElseClause = exprBind(v.ElseClause, hintedNode.(*ast.CaseExpr).ElseClause)
func (hp *hintProcessor) Enter(in ast.Node) (ast.Node, bool) {
switch v := in.(type) {
case *ast.SelectStmt:
if hp.bindHint2Ast {
if hp.tableCounter < len(hp.tableHints) {
v.TableHints = hp.tableHints[hp.tableCounter]
} else {
v.TableHints = nil
}
hp.tableCounter++
} else {
hp.tableHints = append(hp.tableHints, v.TableHints)
}
case *ast.TableName:
if hp.bindHint2Ast {
if hp.indexCounter < len(hp.indexHints) {
v.IndexHints = hp.indexHints[hp.indexCounter]
} else {
v.IndexHints = nil
}
hp.indexCounter++
} else {
hp.indexHints = append(hp.indexHints, v.IndexHints)
}
}
return originalNode
return in, false
}

func resultSetNodeBind(originalNode, hintedNode ast.ResultSetNode) ast.ResultSetNode {
switch x := originalNode.(type) {
case *ast.Join:
return joinBind(x, hintedNode.(*ast.Join))
case *ast.TableSource:
ts, _ := hintedNode.(*ast.TableSource)
switch v := x.Source.(type) {
case *ast.SelectStmt:
x.Source = selectBind(v, ts.Source.(*ast.SelectStmt))
case *ast.UnionStmt:
x.Source = unionSelectBind(v, hintedNode.(*ast.TableSource).Source.(*ast.UnionStmt))
case *ast.TableName:
x.Source.(*ast.TableName).IndexHints = ts.Source.(*ast.TableName).IndexHints
}
return x
case *ast.SelectStmt:
return selectBind(x, hintedNode.(*ast.SelectStmt))
case *ast.UnionStmt:
return unionSelectBind(x, hintedNode.(*ast.UnionStmt))
default:
return x
}
func (hp *hintProcessor) Leave(in ast.Node) (ast.Node, bool) {
return in, true
}

func joinBind(originalNode, hintedNode *ast.Join) *ast.Join {
if originalNode.Left != nil {
originalNode.Left = resultSetNodeBind(originalNode.Left, hintedNode.Left)
}

if hintedNode.Right != nil {
originalNode.Right = resultSetNodeBind(originalNode.Right, hintedNode.Right)
}

return originalNode
// CollectHint collects hints for a statement.
func CollectHint(in ast.StmtNode) *HintsSet {
hp := hintProcessor{HintsSet: &HintsSet{tableHints: make([][]*ast.TableOptimizerHint, 0, 4), indexHints: make([][]*ast.IndexHint, 0, 4)}}
in.Accept(&hp)
return hp.HintsSet
}

func unionSelectBind(originalNode, hintedNode *ast.UnionStmt) ast.ResultSetNode {
selects := originalNode.SelectList.Selects
for i := len(selects) - 1; i >= 0; i-- {
originalNode.SelectList.Selects[i] = selectBind(selects[i], hintedNode.SelectList.Selects[i])
}

return originalNode
// BindHint will add hints for stmt according to the hints in `hintsSet`.
func BindHint(stmt ast.StmtNode, hintsSet *HintsSet) ast.StmtNode {
hp := hintProcessor{HintsSet: hintsSet, bindHint2Ast: true}
stmt.Accept(&hp)
return stmt
}
13 changes: 11 additions & 2 deletions bindinfo/bind_test.go
Expand Up @@ -137,9 +137,9 @@ func (s *testSuite) TestBindParse(c *C) {
c.Check(bindData.UpdateTime, NotNil)

// Test fields with quotes or slashes.
sql = `CREATE GLOBAL BINDING FOR select * from t where a BETWEEN "a" and "b" USING select * from t use index(idx) where a BETWEEN "a\nb\rc\td\0e" and 'x'`
sql = `CREATE GLOBAL BINDING FOR select * from t where i BETWEEN "a" and "b" USING select * from t use index(index_t) where i BETWEEN "a\nb\rc\td\0e" and 'x'`
tk.MustExec(sql)
tk.MustExec(`DROP global binding for select * from t use index(idx) where a BETWEEN "a\nb\rc\td\0e" and "x"`)
tk.MustExec(`DROP global binding for select * from t use index(index_t) where i BETWEEN "a\nb\rc\td\0e" and "x"`)
}

func (s *testSuite) TestGlobalBinding(c *C) {
Expand Down Expand Up @@ -378,6 +378,15 @@ func (s *testSuite) TestGlobalAndSessionBindingBothExist(c *C) {
" └─Selection_14 9990.00 cop not(isnull(test.t2.id))",
" └─TableScan_13 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo",
))

tk.MustExec("drop table if exists t")
tk.MustExec("create table t(a int, b int, index idx(a))")
tk.MustExec("create global binding for select * from t where a > 10 using select * from t ignore index(idx) where a > 10")
// Should not panic for `-1`.
tk.MustContains("select * from t where a > -1", "TableReader")
// Session bindings should be able to cover the global bindings.
tk.MustExec("drop session binding for select * from t where a > 10")
tk.MustIndexLookup("select * from t where a > -1")
}

func (s *testSuite) TestExplain(c *C) {
Expand Down
4 changes: 2 additions & 2 deletions bindinfo/cache.go
Expand Up @@ -16,7 +16,6 @@ package bindinfo
import (
"unsafe"

"github.com/pingcap/parser/ast"
"github.com/pingcap/tidb/metrics"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/chunk"
Expand All @@ -34,7 +33,8 @@ const (
// BindMeta stores the basic bind info and bindSql astNode.
type BindMeta struct {
*BindRecord
Ast ast.StmtNode //ast will be used to do query sql bind check
// Hint is the parsed hints, it is used to bind hints to stmt node.
Hint *HintsSet
}

// cache is a k-v map, key is original sql, value is a slice of BindMeta.
Expand Down
2 changes: 1 addition & 1 deletion bindinfo/handle.go
Expand Up @@ -309,7 +309,7 @@ func (h *BindHandle) newBindMeta(record *BindRecord) (hash string, meta *BindMet
if err != nil {
return "", nil, err
}
meta = &BindMeta{BindRecord: record, Ast: stmtNodes[0]}
meta = &BindMeta{BindRecord: record, Hint: CollectHint(stmtNodes[0])}
return hash, meta, nil
}

Expand Down
2 changes: 1 addition & 1 deletion bindinfo/session_handle.go
Expand Up @@ -50,7 +50,7 @@ func (h *SessionHandle) newBindMeta(record *BindRecord) (hash string, meta *Bind
if err != nil {
return "", nil, err
}
meta = &BindMeta{BindRecord: record, Ast: stmtNodes[0]}
meta = &BindMeta{BindRecord: record, Hint: CollectHint(stmtNodes[0])}
return hash, meta, nil
}

Expand Down
18 changes: 16 additions & 2 deletions executor/bind.go
Expand Up @@ -15,6 +15,7 @@ package executor

import (
"context"
"fmt"

"github.com/opentracing/opentracing-go"
"github.com/pingcap/errors"
Expand All @@ -23,6 +24,7 @@ import (
"github.com/pingcap/tidb/domain"
plannercore "github.com/pingcap/tidb/planner/core"
"github.com/pingcap/tidb/util/chunk"
"github.com/pingcap/tidb/util/sqlexec"
)

// SQLBindExec represents a bind executor.
Expand All @@ -49,7 +51,7 @@ func (e *SQLBindExec) Next(ctx context.Context, req *chunk.Chunk) error {
req.Reset()
switch e.sqlBindOp {
case plannercore.OpSQLBindCreate:
return e.createSQLBind()
return e.createSQLBind(ctx)
case plannercore.OpSQLBindDrop:
return e.dropSQLBind()
default:
Expand All @@ -70,7 +72,19 @@ func (e *SQLBindExec) dropSQLBind() error {
return domain.GetDomain(e.ctx).BindHandle().DropBindRecord(record)
}

func (e *SQLBindExec) createSQLBind() error {
func (e *SQLBindExec) createSQLBind(ctx context.Context) error {
sqlExec := e.ctx.(sqlexec.SQLExecutor)
// Use explain to check the validity of bind sql.
recordSets, err := sqlExec.Execute(ctx, fmt.Sprintf("explain %s", e.bindSQL))
if len(recordSets) > 0 {
if err1 := recordSets[0].Close(); err1 != nil {
return err1
}
}
if err != nil {
return err
}

record := &bindinfo.BindRecord{
OriginalSQL: e.normdOrigSQL,
BindSQL: e.bindSQL,
Expand Down
8 changes: 3 additions & 5 deletions executor/compiler.go
Expand Up @@ -409,13 +409,11 @@ func addHintForSelect(hash, normdOrigSQL string, ctx sessionctx.Context, stmt as
bindRecord = sessionHandle.GetBindRecord(normdOrigSQL, "")
}
if bindRecord != nil {
if bindRecord.Status == bindinfo.Invalid {
return stmt
}
if bindRecord.Status == bindinfo.Using {
metrics.BindUsageCounter.WithLabelValues(metrics.ScopeSession).Inc()
return bindinfo.BindHint(stmt, bindRecord.Ast)
return bindinfo.BindHint(stmt, bindRecord.Hint)
}
return stmt
}
globalHandle := domain.GetDomain(ctx).BindHandle()
bindRecord = globalHandle.GetBindRecord(hash, normdOrigSQL, ctx.GetSessionVars().CurrentDB)
Expand All @@ -424,7 +422,7 @@ func addHintForSelect(hash, normdOrigSQL string, ctx sessionctx.Context, stmt as
}
if bindRecord != nil {
metrics.BindUsageCounter.WithLabelValues(metrics.ScopeGlobal).Inc()
return bindinfo.BindHint(stmt, bindRecord.Ast)
return bindinfo.BindHint(stmt, bindRecord.Hint)
}
return stmt
}
28 changes: 12 additions & 16 deletions util/testkit/testkit.go
Expand Up @@ -190,32 +190,28 @@ func (tk *TestKit) MustUseIndex(sql string, index string, args ...interface{}) b
return false
}

// MustIndexLookup checks whether the plan for the sql is Point_Get.
func (tk *TestKit) MustIndexLookup(sql string, args ...interface{}) *Result {
// MustContains checks whether the plan for the sql contains specific operator.
func (tk *TestKit) MustContains(sql string, op string, args ...interface{}) *Result {
rs := tk.MustQuery("explain "+sql, args...)
hasIndexLookup := false
has := false
for i := range rs.rows {
if strings.Contains(rs.rows[i][0], "IndexLookUp") {
hasIndexLookup = true
if strings.Contains(rs.rows[i][0], op) {
has = true
break
}
}
tk.c.Assert(hasIndexLookup, check.IsTrue)
tk.c.Assert(has, check.IsTrue)
return tk.MustQuery(sql, args...)
}

// MustIndexLookup checks whether the plan for the sql is Point_Get.
func (tk *TestKit) MustIndexLookup(sql string, args ...interface{}) *Result {
return tk.MustContains(sql, "IndexLookUp", args...)
}

// MustTableDual checks whether the plan for the sql is TableDual.
func (tk *TestKit) MustTableDual(sql string, args ...interface{}) *Result {
rs := tk.MustQuery("explain "+sql, args...)
hasTableDual := false
for i := range rs.rows {
if strings.Contains(rs.rows[i][0], "TableDual") {
hasTableDual = true
break
}
}
tk.c.Assert(hasTableDual, check.IsTrue)
return tk.MustQuery(sql, args...)
return tk.MustContains(sql, "TableDual", args...)
}

// MustPointGet checks whether the plan for the sql is Point_Get.
Expand Down

0 comments on commit 1cc43bb

Please sign in to comment.