Skip to content

Commit

Permalink
SQL digester: Modify sql digester logic for IN statement with single …
Browse files Browse the repository at this point in the history
…literal `IN (?) => IN (...)` (pingcap#44601)

ref pingcap#44298
  • Loading branch information
isabella0428 authored and yibin87 committed Oct 31, 2023
1 parent 94d2421 commit cf282b3
Show file tree
Hide file tree
Showing 7 changed files with 181 additions and 18 deletions.
2 changes: 1 addition & 1 deletion bindinfo/internal/testutil.go
Expand Up @@ -35,6 +35,6 @@ func UtilNormalizeWithDefaultDB(t *testing.T, sql string) (normalized, digest st
testParser := parser.New()
stmt, err := testParser.ParseOneStmt(sql, "", "")
require.NoError(t, err)
normalized, digestResult := parser.NormalizeDigest(utilparser.RestoreWithDefaultDB(stmt, "test", ""))
normalized, digestResult := parser.NormalizeDigestForBinding(utilparser.RestoreWithDefaultDB(stmt, "test", ""))
return normalized, digestResult.String()
}
2 changes: 1 addition & 1 deletion bindinfo/tests/BUILD.bazel
Expand Up @@ -9,7 +9,7 @@ go_test(
],
flaky = True,
race = "on",
shard_count = 34,
shard_count = 35,
deps = [
"//bindinfo",
"//bindinfo/internal",
Expand Down
50 changes: 47 additions & 3 deletions bindinfo/tests/bind_test.go
Expand Up @@ -414,7 +414,7 @@ func TestBindingSymbolList(t *testing.T) {
require.True(t, tk.MustUseIndex("select a, b from t where a = 3 limit 1, 100", "ib(b)"))

// Normalize
sql, hash := parser.NormalizeDigest("select a, b from test . t where a = 1 limit 0, 1")
sql, hash := parser.NormalizeDigestForBinding("select a, b from test . t where a = 1 limit 0, 1")

bindData := dom.BindHandle().GetBindRecord(hash.String(), sql, "test")
require.NotNil(t, bindData)
Expand All @@ -429,6 +429,50 @@ func TestBindingSymbolList(t *testing.T) {
require.NotNil(t, bind.UpdateTime)
}

// TestBindingInListWithSingleLiteral verifies that a global binding created
// for a multi-literal IN list also matches a query whose IN list holds a
// single literal. Regression test for #44298.
func TestBindingInListWithSingleLiteral(t *testing.T) {
	store, dom := testkit.CreateMockStoreAndDomain(t)

	tk := testkit.NewTestKit(t, store)
	for _, setupStmt := range []string{
		"use test",
		"drop table if exists t",
		"create table t(a int, b int, INDEX ia (a), INDEX ib (b));",
		"insert into t value(1, 1);",
	} {
		tk.MustExec(setupStmt)
	}

	const (
		// The query under test has exactly one literal in its IN list ...
		query = "select a, b from t where a in (1)"
		// ... while the binding is declared with three.
		createBinding = `create global binding for select a, b from t where a in (1, 2, 3) using select a, b from t use index (ib) where a in (1, 2, 3)`
	)

	// requireIndex runs the query and checks which index was chosen, both via
	// the statement context and via MustUseIndex.
	requireIndex := func(stmtCtxIndex, planIndex string) {
		tk.MustQuery(query)
		require.Equal(t, stmtCtxIndex, tk.Session().GetSessionVars().StmtCtx.IndexNames[0])
		require.True(t, tk.MustUseIndex(query, planIndex))
	}

	// Without a binding the query uses the index on column a.
	requireIndex("t:ia", "ia(a)")

	tk.MustExec(createBinding)

	// With the binding in place the hinted index on column b must be used.
	requireIndex("t:ib", "ib(b)")

	tk.MustQuery("select @@last_plan_from_binding").Check(testkit.Rows("1"))

	// The bind record must be reachable through the binding-specific digest of
	// the single-literal form.
	normalizedSQL, digest := parser.NormalizeDigestForBinding("select a, b from test . t where a in (1)")

	record := dom.BindHandle().GetBindRecord(digest.String(), normalizedSQL, "test")
	require.NotNil(t, record)
	require.Equal(t, "select `a` , `b` from `test` . `t` where `a` in ( ... )", record.OriginalSQL)
	require.Equal(t, "test", record.Db)

	first := record.Bindings[0]
	require.Equal(t, "SELECT `a`,`b` FROM `test`.`t` USE INDEX (`ib`) WHERE `a` IN (1,2,3)", first.BindSQL)
	require.Equal(t, bindinfo.Enabled, first.Status)
	require.NotNil(t, first.Charset)
	require.NotNil(t, first.Collation)
	require.NotNil(t, first.CreateTime)
	require.NotNil(t, first.UpdateTime)
}

func TestDMLSQLBind(t *testing.T) {
store := testkit.CreateMockStore(t)

Expand Down Expand Up @@ -538,7 +582,7 @@ func TestErrorBind(t *testing.T) {
_, err := tk.Exec("create global binding for select * from t where i>100 using select * from t use index(index_t) where i>100")
require.NoError(t, err, "err %v", err)

sql, hash := parser.NormalizeDigest("select * from test . t where i > ?")
sql, hash := parser.NormalizeDigestForBinding("select * from test . t where i > ?")
bindData := dom.BindHandle().GetBindRecord(hash.String(), sql, "test")
require.NotNil(t, bindData)
require.Equal(t, "select * from `test` . `t` where `i` > ?", bindData.OriginalSQL)
Expand Down Expand Up @@ -1304,7 +1348,7 @@ func TestBindSQLDigest(t *testing.T) {
parser4binding := parser.New()
originNode, err := parser4binding.ParseOneStmt(c.origin, "utf8mb4", "utf8mb4_general_ci")
require.NoError(t, err)
_, sqlDigestWithDB := parser.NormalizeDigest(utilparser.RestoreWithDefaultDB(originNode, "test", c.origin))
_, sqlDigestWithDB := parser.NormalizeDigestForBinding(utilparser.RestoreWithDefaultDB(originNode, "test", c.origin))
require.Equal(t, res[0][9], sqlDigestWithDB.String())
}
}
Expand Down
80 changes: 76 additions & 4 deletions parser/digester.go
Expand Up @@ -90,6 +90,18 @@ func Normalize(sql string) (result string) {
return
}

// NormalizeForBinding returns the normalized form of sql with the extra
// rewrite rules used for binding matching applied on top of the generic
// normalization. Hints are not kept.
//
// for example: NormalizeForBinding('select * from t where a in (1)') => 'select * from `t` where `a` in ( ... )'
func NormalizeForBinding(sql string) (result string) {
	// Borrow a digester from the shared pool and hand it back once done.
	digester := digesterPool.Get().(*sqlDigester)
	result = digester.doNormalizeForBinding(sql, false)
	digesterPool.Put(digester)
	return result
}

// NormalizeKeepHint generates the normalized statements, but keep the hints.
// it will get normalized form of statement text with hints.
// which removes general property of a statement but keeps specific property.
Expand All @@ -110,6 +122,14 @@ func NormalizeDigest(sql string) (normalized string, digest *Digest) {
return
}

// NormalizeDigestForBinding normalizes sql with the additional binding rules
// and computes its digest in a single pass, returning both.
func NormalizeDigestForBinding(sql string) (normalized string, digest *Digest) {
	// Borrow a digester from the shared pool and hand it back once done.
	digester := digesterPool.Get().(*sqlDigester)
	normalized, digest = digester.doNormalizeDigestForBinding(sql)
	digesterPool.Put(digester)
	return normalized, digest
}

var digesterPool = sync.Pool{
New: func() interface{} {
return &sqlDigester{
Expand Down Expand Up @@ -141,7 +161,7 @@ func (d *sqlDigester) doDigestNormalized(normalized string) (digest *Digest) {
}

func (d *sqlDigester) doDigest(sql string) (digest *Digest) {
d.normalize(sql, false)
d.normalize(sql, false, false)
d.hasher.Write(d.buffer.Bytes())
d.buffer.Reset()
digest = NewDigest(d.hasher.Sum(nil))
Expand All @@ -150,14 +170,31 @@ func (d *sqlDigester) doDigest(sql string) (digest *Digest) {
}

func (d *sqlDigester) doNormalize(sql string, keepHint bool) (result string) {
d.normalize(sql, keepHint)
d.normalize(sql, keepHint, false)
result = d.buffer.String()
d.buffer.Reset()
return
}

// doNormalizeForBinding normalizes sql with the binding-specific rules
// enabled, returns the buffered text and resets the buffer so the digester
// can be reused.
func (d *sqlDigester) doNormalizeForBinding(sql string, keepHint bool) (result string) {
	const forBinding = true
	d.normalize(sql, keepHint, forBinding)
	result = d.buffer.String()
	d.buffer.Reset()
	return result
}

func (d *sqlDigester) doNormalizeDigest(sql string) (normalized string, digest *Digest) {
d.normalize(sql, false)
d.normalize(sql, false, false)
normalized = d.buffer.String()
d.hasher.Write(d.buffer.Bytes())
d.buffer.Reset()
digest = NewDigest(d.hasher.Sum(nil))
d.hasher.Reset()
return
}

func (d *sqlDigester) doNormalizeDigestForBinding(sql string) (normalized string, digest *Digest) {
d.normalize(sql, false, true)
normalized = d.buffer.String()
d.hasher.Write(d.buffer.Bytes())
d.buffer.Reset()
Expand All @@ -175,7 +212,7 @@ const (
genericSymbolList = -2
)

func (d *sqlDigester) normalize(sql string, keepHint bool) {
func (d *sqlDigester) normalize(sql string, keepHint bool, forBinding bool) {
d.lexer.reset(sql)
d.lexer.setKeepHint(keepHint)
for {
Expand All @@ -194,6 +231,12 @@ func (d *sqlDigester) normalize(sql string, keepHint bool) {

d.reduceLit(&currTok)

// Apply binding matching specific rules
if forBinding {
// IN (?) => IN ( ... ) #44298
d.reduceInListWithSingleLiteral(&currTok)
}

if currTok.tok == identifier {
if strings.HasPrefix(currTok.lit, "_") {
_, err := charset.GetCharsetInfo(currTok.lit[1:])
Expand Down Expand Up @@ -334,6 +377,20 @@ func (d *sqlDigester) isGenericLists(last4 []token) bool {
return true
}

// reduceInListWithSingleLiteral collapses a single-literal IN list so that
// `IN ( ? )` is recorded as `IN ( ... )`. Issue: #44298
//
// currTok is the token currently being processed. The rewrite happens only
// when currTok is ")" and the three most recently stored tokens are `in`, `(`
// and a generic symbol; in that case the generic symbol is swapped for a
// generic symbol list. currTok itself is left untouched.
func (d *sqlDigester) reduceInListWithSingleLiteral(currTok *token) {
	tail := d.tokens.back(3)
	if len(tail) != 3 {
		return
	}
	inTok, openTok, litTok := tail[0], tail[1], tail[2]
	if !d.isInKeyword(inTok) || !d.isLeftParen(openTok) {
		return
	}
	if litTok.tok != genericSymbol || !d.isRightParen(*currTok) {
		return
	}
	// Replace the lone `?` with `...`.
	d.tokens.popBack(1)
	d.tokens.pushBack(token{genericSymbolList, "..."})
}

func (d *sqlDigester) isPrefixByUnary(currTok int) (isUnary bool) {
if !d.isNumLit(currTok) {
return
Expand Down Expand Up @@ -442,6 +499,21 @@ func (*sqlDigester) isComma(tok token) (isComma bool) {
return
}

// isLeftParen reports whether the literal text of tok is exactly "(".
func (*sqlDigester) isLeftParen(tok token) (isLeftParen bool) {
	return tok.lit == "("
}

// isRightParen reports whether the literal text of tok is exactly ")".
func (*sqlDigester) isRightParen(tok token) (isRightParen bool) {
	isRightParen = tok.lit == ")"
	return
}

// isInKeyword reports whether the literal text of tok is exactly "in".
// The comparison is case-sensitive.
// NOTE(review): presumably the digester lexer has already lower-cased
// keywords before tokens reach this check; confirm.
func (*sqlDigester) isInKeyword(tok token) (isInKeyword bool) {
	return tok.lit == "in"
}

type token struct {
tok int
lit string
Expand Down
27 changes: 25 additions & 2 deletions parser/digester_test.go
Expand Up @@ -24,10 +24,11 @@ import (
)

func TestNormalize(t *testing.T) {
tests := []struct {
tests_for_generic_normalization_rules := []struct {
input string
expect string
}{
// Generic normalization rules
{"select _utf8mb4'123'", "select (_charset) ?"},
{"SELECT 1", "select ?"},
{"select null", "select ?"},
Expand Down Expand Up @@ -68,7 +69,7 @@ func TestNormalize(t *testing.T) {
{"insert into t values (1), (2)", "insert into `t` values ( ... )"},
{"insert into t values (1)", "insert into `t` values ( ? )"},
}
for _, test := range tests {
for _, test := range tests_for_generic_normalization_rules {
normalized := parser.Normalize(test.input)
digest := parser.DigestNormalized(normalized)
require.Equal(t, test.expect, normalized)
Expand All @@ -77,6 +78,28 @@ func TestNormalize(t *testing.T) {
require.Equal(t, normalized, normalized2)
require.Equalf(t, digest.String(), digest2.String(), "%+v", test)
}

tests_for_binding_specific_rules := []struct {
input string
expect string
}{
// Binding specific rules
// IN (Lit) => IN ( ... ) #44298
{"select * from t where a in (1)", "select * from `t` where `a` in ( ... )"},
{"select * from t where (a, b) in ((1, 1))", "select * from `t` where ( `a` , `b` ) in ( ( ... ) )"},
{"select * from t where (a, b) in ((1, 1), (2, 2))", "select * from `t` where ( `a` , `b` ) in ( ( ... ) )"},
{"select * from t where a in(1, 2)", "select * from `t` where `a` in ( ... )"},
{"select * from t where a in(1, 2, 3)", "select * from `t` where `a` in ( ... )"},
}
for _, test := range tests_for_binding_specific_rules {
normalized := parser.NormalizeForBinding(test.input)
digest := parser.DigestNormalized(normalized)
require.Equal(t, test.expect, normalized)

normalized2, digest2 := parser.NormalizeDigestForBinding(test.input)
require.Equal(t, normalized, normalized2)
require.Equalf(t, digest.String(), digest2.String(), "%+v", test)
}
}

func TestNormalizeKeepHint(t *testing.T) {
Expand Down
4 changes: 2 additions & 2 deletions planner/core/plan_cache_utils.go
Expand Up @@ -58,7 +58,7 @@ var (
PreparedPlanCacheMaxMemory = *atomic2.NewUint64(math.MaxUint64)

// ExtractSelectAndNormalizeDigest extracts the select statement and normalizes it.
ExtractSelectAndNormalizeDigest func(stmtNode ast.StmtNode, specifiledDB string) (ast.StmtNode, string, string, error)
ExtractSelectAndNormalizeDigest func(stmtNode ast.StmtNode, specifiledDB string, forBinding bool) (ast.StmtNode, string, string, error)
)

type paramMarkerExtractor struct {
Expand Down Expand Up @@ -145,7 +145,7 @@ func GeneratePlanCacheStmtWithAST(ctx context.Context, sctx sessionctx.Context,
if !cacheable {
sctx.GetSessionVars().StmtCtx.AppendWarning(errors.Errorf("skip prepared plan-cache: " + reason))
}
selectStmtNode, normalizedSQL4PC, digest4PC, err = ExtractSelectAndNormalizeDigest(paramStmt, vars.CurrentDB)
selectStmtNode, normalizedSQL4PC, digest4PC, err = ExtractSelectAndNormalizeDigest(paramStmt, vars.CurrentDB, false)
if err != nil || selectStmtNode == nil {
normalizedSQL4PC = ""
digest4PC = ""
Expand Down
34 changes: 29 additions & 5 deletions planner/optimize.go
Expand Up @@ -571,7 +571,7 @@ func buildLogicalPlan(ctx context.Context, sctx sessionctx.Context, node ast.Nod
}

// ExtractSelectAndNormalizeDigest extracts the select statement and normalizes it.
func ExtractSelectAndNormalizeDigest(stmtNode ast.StmtNode, specifiledDB string) (ast.StmtNode, string, string, error) {
func ExtractSelectAndNormalizeDigest(stmtNode ast.StmtNode, specifiledDB string, forBinding bool) (ast.StmtNode, string, string, error) {
switch x := stmtNode.(type) {
case *ast.ExplainStmt:
// This function is only used to find bind record.
Expand All @@ -584,18 +584,33 @@ func ExtractSelectAndNormalizeDigest(stmtNode ast.StmtNode, specifiledDB string)
}
switch x.Stmt.(type) {
case *ast.SelectStmt, *ast.DeleteStmt, *ast.UpdateStmt, *ast.InsertStmt:
normalizeSQL := parser.Normalize(utilparser.RestoreWithDefaultDB(x.Stmt, specifiledDB, x.Text()))
var normalizeSQL string
if forBinding {
// Apply additional binding rules if enabled
normalizeSQL = parser.NormalizeForBinding(utilparser.RestoreWithDefaultDB(x.Stmt, specifiledDB, x.Text()))
} else {
normalizeSQL = parser.Normalize(utilparser.RestoreWithDefaultDB(x.Stmt, specifiledDB, x.Text()))
}
normalizeSQL = core.EraseLastSemicolonInSQL(normalizeSQL)
hash := parser.DigestNormalized(normalizeSQL)
return x.Stmt, normalizeSQL, hash.String(), nil
case *ast.SetOprStmt:
core.EraseLastSemicolon(x)
var normalizeExplainSQL string
var explainSQL string
if specifiledDB != "" {
normalizeExplainSQL = parser.Normalize(utilparser.RestoreWithDefaultDB(x, specifiledDB, x.Text()))
explainSQL = utilparser.RestoreWithDefaultDB(x, specifiledDB, x.Text())
} else {
explainSQL = x.Text()
}

if forBinding {
// Apply additional binding rules
normalizeExplainSQL = parser.NormalizeForBinding(explainSQL)
} else {
normalizeExplainSQL = parser.Normalize(x.Text())
}

idx := strings.Index(normalizeExplainSQL, "select")
parenthesesIdx := strings.Index(normalizeExplainSQL, "(")
if parenthesesIdx != -1 && parenthesesIdx < idx {
Expand All @@ -615,7 +630,16 @@ func ExtractSelectAndNormalizeDigest(stmtNode ast.StmtNode, specifiledDB string)
if len(x.Text()) == 0 {
return x, "", "", nil
}
normalizedSQL, hash := parser.NormalizeDigest(utilparser.RestoreWithDefaultDB(x, specifiledDB, x.Text()))

var normalizedSQL string
var hash *parser.Digest
if forBinding {
// Apply additional binding rules
normalizedSQL, hash = parser.NormalizeDigestForBinding(utilparser.RestoreWithDefaultDB(x, specifiledDB, x.Text()))
} else {
normalizedSQL, hash = parser.NormalizeDigest(utilparser.RestoreWithDefaultDB(x, specifiledDB, x.Text()))
}

return x, normalizedSQL, hash.String(), nil
}
return nil, "", "", nil
Expand All @@ -626,7 +650,7 @@ func getBindRecord(ctx sessionctx.Context, stmt ast.StmtNode) (*bindinfo.BindRec
if ctx.Value(bindinfo.SessionBindInfoKeyType) == nil {
return nil, "", nil
}
stmtNode, normalizedSQL, hash, err := ExtractSelectAndNormalizeDigest(stmt, ctx.GetSessionVars().CurrentDB)
stmtNode, normalizedSQL, hash, err := ExtractSelectAndNormalizeDigest(stmt, ctx.GetSessionVars().CurrentDB, true)
if err != nil || stmtNode == nil {
return nil, "", err
}
Expand Down

0 comments on commit cf282b3

Please sign in to comment.