Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bindinfo: fix bugs when using bindings #14263

Merged
merged 4 commits into from Feb 4, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
185 changes: 46 additions & 139 deletions bindinfo/bind.go
Expand Up @@ -15,155 +15,62 @@ package bindinfo

import "github.com/pingcap/parser/ast"

// BindHint will add hints for originStmt according to hintedStmt' hints.
func BindHint(originStmt, hintedStmt ast.StmtNode) ast.StmtNode {
switch x := originStmt.(type) {
case *ast.SelectStmt:
return selectBind(x, hintedStmt.(*ast.SelectStmt))
default:
return originStmt
}
}

func selectBind(originalNode, hintedNode *ast.SelectStmt) *ast.SelectStmt {
if hintedNode.TableHints != nil {
originalNode.TableHints = hintedNode.TableHints
}
if originalNode.From != nil {
originalNode.From.TableRefs = resultSetNodeBind(originalNode.From.TableRefs, hintedNode.From.TableRefs).(*ast.Join)
}
if originalNode.Where != nil {
originalNode.Where = exprBind(originalNode.Where, hintedNode.Where).(ast.ExprNode)
}

if originalNode.Having != nil {
originalNode.Having.Expr = exprBind(originalNode.Having.Expr, hintedNode.Having.Expr)
}

if originalNode.OrderBy != nil {
originalNode.OrderBy = orderByBind(originalNode.OrderBy, hintedNode.OrderBy)
}

if originalNode.Fields != nil {
origFields := originalNode.Fields.Fields
hintFields := hintedNode.Fields.Fields
for idx := range origFields {
origFields[idx].Expr = exprBind(origFields[idx].Expr, hintFields[idx].Expr)
}
}
return originalNode
// HintsSet contains all hints of a query.
type HintsSet struct {
tableHints [][]*ast.TableOptimizerHint // Slice offset is the traversal order of `SelectStmt` in the ast.
indexHints [][]*ast.IndexHint // Slice offset is the traversal order of `TableName` in the ast.
}

func orderByBind(originalNode, hintedNode *ast.OrderByClause) *ast.OrderByClause {
for idx := 0; idx < len(originalNode.Items); idx++ {
originalNode.Items[idx].Expr = exprBind(originalNode.Items[idx].Expr, hintedNode.Items[idx].Expr)
}
return originalNode
type hintProcessor struct {
*HintsSet
// bindHint2Ast indicates the behavior of the processor, `true` for bind hint to ast, `false` for extract hint from ast.
bindHint2Ast bool
tableCounter int
indexCounter int
}

func exprBind(originalNode, hintedNode ast.ExprNode) ast.ExprNode {
switch v := originalNode.(type) {
case *ast.SubqueryExpr:
if v.Query != nil {
v.Query = resultSetNodeBind(v.Query, hintedNode.(*ast.SubqueryExpr).Query)
}
case *ast.ExistsSubqueryExpr:
if v.Sel != nil {
v.Sel.(*ast.SubqueryExpr).Query = resultSetNodeBind(v.Sel.(*ast.SubqueryExpr).Query, hintedNode.(*ast.ExistsSubqueryExpr).Sel.(*ast.SubqueryExpr).Query)
}
case *ast.PatternInExpr:
if v.Sel != nil {
v.Sel.(*ast.SubqueryExpr).Query = resultSetNodeBind(v.Sel.(*ast.SubqueryExpr).Query, hintedNode.(*ast.PatternInExpr).Sel.(*ast.SubqueryExpr).Query)
}
case *ast.BinaryOperationExpr:
if v.L != nil {
v.L = exprBind(v.L, hintedNode.(*ast.BinaryOperationExpr).L)
}
if v.R != nil {
v.R = exprBind(v.R, hintedNode.(*ast.BinaryOperationExpr).R)
}
case *ast.IsNullExpr:
if v.Expr != nil {
v.Expr = exprBind(v.Expr, hintedNode.(*ast.IsNullExpr).Expr)
}
case *ast.IsTruthExpr:
if v.Expr != nil {
v.Expr = exprBind(v.Expr, hintedNode.(*ast.IsTruthExpr).Expr)
}
case *ast.PatternLikeExpr:
if v.Pattern != nil {
v.Pattern = exprBind(v.Pattern, hintedNode.(*ast.PatternLikeExpr).Pattern)
}
case *ast.CompareSubqueryExpr:
if v.L != nil {
v.L = exprBind(v.L, hintedNode.(*ast.CompareSubqueryExpr).L)
}
if v.R != nil {
v.R = exprBind(v.R, hintedNode.(*ast.CompareSubqueryExpr).R)
}
case *ast.BetweenExpr:
if v.Left != nil {
v.Left = exprBind(v.Left, hintedNode.(*ast.BetweenExpr).Left)
}
if v.Right != nil {
v.Right = exprBind(v.Right, hintedNode.(*ast.BetweenExpr).Right)
}
case *ast.UnaryOperationExpr:
if v.V != nil {
v.V = exprBind(v.V, hintedNode.(*ast.UnaryOperationExpr).V)
}
case *ast.CaseExpr:
if v.Value != nil {
v.Value = exprBind(v.Value, hintedNode.(*ast.CaseExpr).Value)
}
if v.ElseClause != nil {
v.ElseClause = exprBind(v.ElseClause, hintedNode.(*ast.CaseExpr).ElseClause)
func (hp *hintProcessor) Enter(in ast.Node) (ast.Node, bool) {
switch v := in.(type) {
case *ast.SelectStmt:
if hp.bindHint2Ast {
if hp.tableCounter < len(hp.tableHints) {
v.TableHints = hp.tableHints[hp.tableCounter]
} else {
v.TableHints = nil
}
hp.tableCounter++
} else {
hp.tableHints = append(hp.tableHints, v.TableHints)
}
case *ast.TableName:
if hp.bindHint2Ast {
if hp.indexCounter < len(hp.indexHints) {
v.IndexHints = hp.indexHints[hp.indexCounter]
} else {
v.IndexHints = nil
}
hp.indexCounter++
} else {
hp.indexHints = append(hp.indexHints, v.IndexHints)
}
}
return originalNode
return in, false
}

func resultSetNodeBind(originalNode, hintedNode ast.ResultSetNode) ast.ResultSetNode {
switch x := originalNode.(type) {
case *ast.Join:
return joinBind(x, hintedNode.(*ast.Join))
case *ast.TableSource:
ts, _ := hintedNode.(*ast.TableSource)
switch v := x.Source.(type) {
case *ast.SelectStmt:
x.Source = selectBind(v, ts.Source.(*ast.SelectStmt))
case *ast.UnionStmt:
x.Source = unionSelectBind(v, hintedNode.(*ast.TableSource).Source.(*ast.UnionStmt))
case *ast.TableName:
x.Source.(*ast.TableName).IndexHints = ts.Source.(*ast.TableName).IndexHints
}
return x
case *ast.SelectStmt:
return selectBind(x, hintedNode.(*ast.SelectStmt))
case *ast.UnionStmt:
return unionSelectBind(x, hintedNode.(*ast.UnionStmt))
default:
return x
}
func (hp *hintProcessor) Leave(in ast.Node) (ast.Node, bool) {
return in, true
}

func joinBind(originalNode, hintedNode *ast.Join) *ast.Join {
if originalNode.Left != nil {
originalNode.Left = resultSetNodeBind(originalNode.Left, hintedNode.Left)
}

if hintedNode.Right != nil {
originalNode.Right = resultSetNodeBind(originalNode.Right, hintedNode.Right)
}

return originalNode
// CollectHint collects hints for a statement.
func CollectHint(in ast.StmtNode) *HintsSet {
hp := hintProcessor{HintsSet: &HintsSet{tableHints: make([][]*ast.TableOptimizerHint, 0, 4), indexHints: make([][]*ast.IndexHint, 0, 4)}}
in.Accept(&hp)
return hp.HintsSet
}

func unionSelectBind(originalNode, hintedNode *ast.UnionStmt) ast.ResultSetNode {
selects := originalNode.SelectList.Selects
for i := len(selects) - 1; i >= 0; i-- {
originalNode.SelectList.Selects[i] = selectBind(selects[i], hintedNode.SelectList.Selects[i])
}

return originalNode
// BindHint will add hints for stmt according to the hints in `hintsSet`.
func BindHint(stmt ast.StmtNode, hintsSet *HintsSet) ast.StmtNode {
hp := hintProcessor{HintsSet: hintsSet, bindHint2Ast: true}
stmt.Accept(&hp)
return stmt
}
13 changes: 11 additions & 2 deletions bindinfo/bind_test.go
Expand Up @@ -137,9 +137,9 @@ func (s *testSuite) TestBindParse(c *C) {
c.Check(bindData.UpdateTime, NotNil)

// Test fields with quotes or slashes.
sql = `CREATE GLOBAL BINDING FOR select * from t where a BETWEEN "a" and "b" USING select * from t use index(idx) where a BETWEEN "a\nb\rc\td\0e" and 'x'`
sql = `CREATE GLOBAL BINDING FOR select * from t where i BETWEEN "a" and "b" USING select * from t use index(index_t) where i BETWEEN "a\nb\rc\td\0e" and 'x'`
tk.MustExec(sql)
tk.MustExec(`DROP global binding for select * from t use index(idx) where a BETWEEN "a\nb\rc\td\0e" and "x"`)
tk.MustExec(`DROP global binding for select * from t use index(index_t) where i BETWEEN "a\nb\rc\td\0e" and "x"`)
}

func (s *testSuite) TestGlobalBinding(c *C) {
Expand Down Expand Up @@ -378,6 +378,15 @@ func (s *testSuite) TestGlobalAndSessionBindingBothExist(c *C) {
" └─Selection_14 9990.00 cop not(isnull(test.t2.id))",
" └─TableScan_13 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo",
))

tk.MustExec("drop table if exists t")
tk.MustExec("create table t(a int, b int, index idx(a))")
tk.MustExec("create global binding for select * from t where a > 10 using select * from t ignore index(idx) where a > 10")
// Should not panic for `-1`.
tk.MustContains("select * from t where a > -1", "TableReader")
// Session bindings should be able to cover the global bindings.
tk.MustExec("drop session binding for select * from t where a > 10")
tk.MustIndexLookup("select * from t where a > -1")
}

func (s *testSuite) TestExplain(c *C) {
Expand Down
4 changes: 2 additions & 2 deletions bindinfo/cache.go
Expand Up @@ -16,7 +16,6 @@ package bindinfo
import (
"unsafe"

"github.com/pingcap/parser/ast"
"github.com/pingcap/tidb/metrics"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/chunk"
Expand All @@ -34,7 +33,8 @@ const (
// BindMeta stores the basic bind info and bindSql astNode.
type BindMeta struct {
*BindRecord
Ast ast.StmtNode //ast will be used to do query sql bind check
// Hint is the parsed hints, it is used to bind hints to stmt node.
Hint *HintsSet
eurekaka marked this conversation as resolved.
Show resolved Hide resolved
}

// cache is a k-v map, key is original sql, value is a slice of BindMeta.
Expand Down
2 changes: 1 addition & 1 deletion bindinfo/handle.go
Expand Up @@ -309,7 +309,7 @@ func (h *BindHandle) newBindMeta(record *BindRecord) (hash string, meta *BindMet
if err != nil {
return "", nil, err
}
meta = &BindMeta{BindRecord: record, Ast: stmtNodes[0]}
meta = &BindMeta{BindRecord: record, Hint: CollectHint(stmtNodes[0])}
return hash, meta, nil
}

Expand Down
2 changes: 1 addition & 1 deletion bindinfo/session_handle.go
Expand Up @@ -50,7 +50,7 @@ func (h *SessionHandle) newBindMeta(record *BindRecord) (hash string, meta *Bind
if err != nil {
return "", nil, err
}
meta = &BindMeta{BindRecord: record, Ast: stmtNodes[0]}
meta = &BindMeta{BindRecord: record, Hint: CollectHint(stmtNodes[0])}
return hash, meta, nil
}

Expand Down
18 changes: 16 additions & 2 deletions executor/bind.go
Expand Up @@ -15,6 +15,7 @@ package executor

import (
"context"
"fmt"

"github.com/opentracing/opentracing-go"
"github.com/pingcap/errors"
Expand All @@ -23,6 +24,7 @@ import (
"github.com/pingcap/tidb/domain"
plannercore "github.com/pingcap/tidb/planner/core"
"github.com/pingcap/tidb/util/chunk"
"github.com/pingcap/tidb/util/sqlexec"
)

// SQLBindExec represents a bind executor.
Expand All @@ -48,7 +50,7 @@ func (e *SQLBindExec) Next(ctx context.Context, req *chunk.Chunk) error {
req.Reset()
switch e.sqlBindOp {
case plannercore.OpSQLBindCreate:
return e.createSQLBind()
return e.createSQLBind(ctx)
case plannercore.OpSQLBindDrop:
return e.dropSQLBind()
default:
Expand All @@ -69,7 +71,19 @@ func (e *SQLBindExec) dropSQLBind() error {
return domain.GetDomain(e.ctx).BindHandle().DropBindRecord(record)
}

func (e *SQLBindExec) createSQLBind() error {
func (e *SQLBindExec) createSQLBind(ctx context.Context) error {
sqlExec := e.ctx.(sqlexec.SQLExecutor)
// Use explain to check the validity of bind sql.
recordSets, err := sqlExec.Execute(ctx, fmt.Sprintf("explain %s", e.bindSQL))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

explain is much more expensive than parsing and plan building?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, but it much more easier by using explain, and it could also check the privilege of select tables.

if len(recordSets) > 0 {
if err1 := recordSets[0].Close(); err1 != nil {
return err1
}
}
if err != nil {
return err
}

record := &bindinfo.BindRecord{
OriginalSQL: e.normdOrigSQL,
BindSQL: e.bindSQL,
Expand Down
8 changes: 3 additions & 5 deletions executor/compiler.go
Expand Up @@ -406,13 +406,11 @@ func addHintForSelect(hash, normdOrigSQL string, ctx sessionctx.Context, stmt as
sessionHandle := ctx.Value(bindinfo.SessionBindInfoKeyType).(*bindinfo.SessionHandle)
bindRecord := sessionHandle.GetBindRecord(normdOrigSQL, ctx.GetSessionVars().CurrentDB)
if bindRecord != nil {
if bindRecord.Status == bindinfo.Invalid {
return stmt
}
if bindRecord.Status == bindinfo.Using {
metrics.BindUsageCounter.WithLabelValues(metrics.ScopeSession).Inc()
return bindinfo.BindHint(stmt, bindRecord.Ast)
return bindinfo.BindHint(stmt, bindRecord.Hint)
}
return stmt
}
globalHandle := domain.GetDomain(ctx).BindHandle()
bindRecord = globalHandle.GetBindRecord(hash, normdOrigSQL, ctx.GetSessionVars().CurrentDB)
Expand All @@ -421,7 +419,7 @@ func addHintForSelect(hash, normdOrigSQL string, ctx sessionctx.Context, stmt as
}
if bindRecord != nil {
metrics.BindUsageCounter.WithLabelValues(metrics.ScopeGlobal).Inc()
return bindinfo.BindHint(stmt, bindRecord.Ast)
return bindinfo.BindHint(stmt, bindRecord.Hint)
}
return stmt
}
28 changes: 12 additions & 16 deletions util/testkit/testkit.go
Expand Up @@ -190,32 +190,28 @@ func (tk *TestKit) MustUseIndex(sql string, index string, args ...interface{}) b
return false
}

// MustIndexLookup checks whether the plan for the sql is Point_Get.
func (tk *TestKit) MustIndexLookup(sql string, args ...interface{}) *Result {
// MustContains checks whether the plan for the sql contains specific operator.
func (tk *TestKit) MustContains(sql string, op string, args ...interface{}) *Result {
rs := tk.MustQuery("explain "+sql, args...)
hasIndexLookup := false
has := false
for i := range rs.rows {
if strings.Contains(rs.rows[i][0], "IndexLookUp") {
hasIndexLookup = true
if strings.Contains(rs.rows[i][0], op) {
has = true
break
}
}
tk.c.Assert(hasIndexLookup, check.IsTrue)
tk.c.Assert(has, check.IsTrue)
return tk.MustQuery(sql, args...)
}

// MustIndexLookup checks whether the plan for the sql is Point_Get.
func (tk *TestKit) MustIndexLookup(sql string, args ...interface{}) *Result {
return tk.MustContains(sql, "IndexLookUp", args...)
}

// MustTableDual checks whether the plan for the sql is TableDual.
func (tk *TestKit) MustTableDual(sql string, args ...interface{}) *Result {
rs := tk.MustQuery("explain "+sql, args...)
hasTableDual := false
for i := range rs.rows {
if strings.Contains(rs.rows[i][0], "TableDual") {
hasTableDual = true
break
}
}
tk.c.Assert(hasTableDual, check.IsTrue)
return tk.MustQuery(sql, args...)
return tk.MustContains(sql, "TableDual", args...)
}

// MustPointGet checks whether the plan for the sql is Point_Get.
Expand Down