From 62d74b841ca15f96f0567a2b41dfbd677824fb7e Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Wed, 2 Sep 2020 10:13:50 +0200 Subject: [PATCH 1/8] Rewrite joins written with the USING construct If a join has been expressed as: ``` FROM tblA JOIN tblB USING (col1,col2) ``` it will be rewritten to ``` FROM tblA JOIN tblB ON tblA.col1 = tblB.col1 AND tblA.col2 = tblB.col2 ``` This allows our planner to recognize these queries and plan them correctly. Signed-off-by: Andres Taylor --- go/vt/sqlparser/ast_funcs.go | 28 +++ go/vt/sqlparser/expression_rewriting.go | 42 +++++ go/vt/sqlparser/expression_rewriting_test.go | 171 ++++++++---------- go/vt/vtgate/planbuilder/builder.go | 4 +- .../planbuilder/testdata/from_cases.txt | 37 ++++ .../testdata/unsupported_cases.txt | 4 - 6 files changed, 189 insertions(+), 97 deletions(-) diff --git a/go/vt/sqlparser/ast_funcs.go b/go/vt/sqlparser/ast_funcs.go index 33517ad332a..f7673b613d1 100644 --- a/go/vt/sqlparser/ast_funcs.go +++ b/go/vt/sqlparser/ast_funcs.go @@ -21,6 +21,9 @@ import ( "encoding/json" "strings" + vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" + "vitess.io/vitess/go/vt/vterrors" + "vitess.io/vitess/go/vt/log" "vitess.io/vitess/go/sqltypes" @@ -342,6 +345,20 @@ func (node *AliasedTableExpr) RemoveHints() *AliasedTableExpr { return &noHints } +//TableName returns a TableName pointing to this table expr +func (node *AliasedTableExpr) TableName() (TableName, error) { + if !node.As.IsEmpty() { + return TableName{Name: node.As}, nil + } + + tableName, ok := node.Expr.(TableName) + if !ok { + return TableName{}, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "BUG: the AST has changed. This should not be possible") + } + + return tableName, nil +} + // IsEmpty returns true if TableName is nil or empty. func (node TableName) IsEmpty() bool { // If Name is empty, Qualifier is also empty. @@ -514,6 +531,17 @@ func NewColName(str string) *ColName { } } +// NewColNameWithQualifier makes a new ColName pointing to a specific table +func NewColNameWithQualifier(identifier string, table TableName) *ColName { + return &ColName{ + Name: NewColIdent(identifier), + Qualifier: TableName{ + Name: NewTableIdent(table.Name.String()), + Qualifier: NewTableIdent(table.Qualifier.String()), + }, + } +} + //NewSelect is used to create a select statement func NewSelect(comments Comments, exprs SelectExprs, selectOptions []string, from TableExprs, where *Where, groupBy GroupBy, having *Where) *Select { var cache *bool diff --git a/go/vt/sqlparser/expression_rewriting.go b/go/vt/sqlparser/expression_rewriting.go index 1ea4bae0db4..07895faf3e9 100644 --- a/go/vt/sqlparser/expression_rewriting.go +++ b/go/vt/sqlparser/expression_rewriting.go @@ -161,6 +161,48 @@ func (er *expressionRewriter) goingDown(cursor *Cursor) bool { cursor.Replace(bindVarExpression(UserDefinedVariableName + udv)) er.needBindVarFor(udv) } + + case JoinCondition: + if node.Using != nil { + joinTableExpr, ok := cursor.Parent().(*JoinTableExpr) + if !ok { + // this is not possible with the current AST + break + } + leftTable, leftOk := joinTableExpr.LeftExpr.(*AliasedTableExpr) + rightTable, rightOk := joinTableExpr.RightExpr.(*AliasedTableExpr) + if !(leftOk && rightOk) { + // we only deal with simple FROM A JOIN B USING queries at the moment + break + } + lft, err := leftTable.TableName() + if err != nil { + er.err = err + break + } + rgt, err := rightTable.TableName() + if err != nil { + er.err = err + break + } + newCondition := JoinCondition{} + for _, colIdent := range node.Using { + lftCol := NewColNameWithQualifier(colIdent.String(), lft) + rgtCol := NewColNameWithQualifier(colIdent.String(), rgt) + cmp := &ComparisonExpr{ + Operator: EqualStr, + Left: lftCol, + Right: rgtCol, + } + if newCondition.On == nil { + newCondition.On = cmp + } else { + newCondition.On = &AndExpr{Left: newCondition.On, Right: cmp} + } + } + cursor.Replace(newCondition) + } + } return true } diff --git a/go/vt/sqlparser/expression_rewriting_test.go b/go/vt/sqlparser/expression_rewriting_test.go index 110c4883373..fe6d90623d0 100644 --- a/go/vt/sqlparser/expression_rewriting_test.go +++ b/go/vt/sqlparser/expression_rewriting_test.go @@ -29,98 +29,85 @@ type myTestCase struct { } func TestRewrites(in *testing.T) { - tests := []myTestCase{ - { - in: "SELECT 42", - expected: "SELECT 42", - // no bindvar needs - }, - { - in: "SELECT last_insert_id()", - expected: "SELECT :__lastInsertId as `last_insert_id()`", - liid: true, - }, - { - in: "SELECT database()", - expected: "SELECT :__vtdbname as `database()`", - db: true, - }, - { - in: "SELECT database() from test", - expected: "SELECT database() from test", - // no bindvar needs - }, - { - in: "SELECT last_insert_id() as test", - expected: "SELECT :__lastInsertId as test", - liid: true, - }, - { - in: "SELECT last_insert_id() + database()", - expected: "SELECT :__lastInsertId + :__vtdbname as `last_insert_id() + database()`", - db: true, liid: true, - }, - { - in: "select (select database()) from test", - expected: "select (select database() from dual) from test", - // no bindvar needs - }, - { - in: "select (select database() from dual) from test", - expected: "select (select database() from dual) from test", - // no bindvar needs - }, - { - in: "select (select database() from dual) from dual", - expected: "select (select :__vtdbname as `database()` from dual) as `(select database() from dual)` from dual", - db: true, - }, - { - in: "select id from user where database()", - expected: "select id from user where database()", - // no bindvar needs - }, - { - in: "select table_name from information_schema.tables where table_schema = database()", - expected: "select table_name from information_schema.tables where table_schema = database()", - // no bindvar needs - }, - { - in: "select schema()", - expected: "select :__vtdbname as `schema()`", - db: true, - }, - { - in: "select found_rows()", - expected: "select :__vtfrows as `found_rows()`", - foundRows: true, - }, - { - in: "select @`x y`", - expected: "select :__vtudvx_y as `@``x y``` from dual", - udv: 1, - }, - { - in: "select id from t where id = @x and val = @y", - expected: "select id from t where id = :__vtudvx and val = :__vtudvy", - db: false, udv: 2, - }, - { - in: "insert into t(id) values(@xyx)", - expected: "insert into t(id) values(:__vtudvxyx)", - db: false, udv: 1, - }, - { - in: "select row_count()", - expected: "select :__vtrcount as `row_count()`", - rowCount: true, - }, - { - in: "SELECT lower(database())", - expected: "SELECT lower(:__vtdbname) as `lower(database())`", - db: true, - }, - } + tests := []myTestCase{{ + in: "SELECT 42", + expected: "SELECT 42", + // no bindvar needs + }, { + in: "SELECT last_insert_id()", + expected: "SELECT :__lastInsertId as `last_insert_id()`", + liid: true, + }, { + in: "SELECT database()", + expected: "SELECT :__vtdbname as `database()`", + db: true, + }, { + in: "SELECT database() from test", + expected: "SELECT database() from test", + // no bindvar needs + }, { + in: "SELECT last_insert_id() as test", + expected: "SELECT :__lastInsertId as test", + liid: true, + }, { + in: "SELECT last_insert_id() + database()", + expected: "SELECT :__lastInsertId + :__vtdbname as `last_insert_id() + database()`", + db: true, liid: true, + }, { + in: "select (select database()) from test", + expected: "select (select database() from dual) from test", + // no bindvar needs + }, { + in: "select (select database() from dual) from test", + expected: "select (select database() from dual) from test", + // no bindvar needs + }, { + in: "select (select database() from dual) from dual", + expected: "select (select :__vtdbname as `database()` from dual) as `(select database() from dual)` from dual", + db: true, + }, { + in: "select id from user where database()", + expected: "select id from user where database()", + // no bindvar needs + }, { + in: "select table_name from information_schema.tables where table_schema = database()", + expected: "select table_name from information_schema.tables where table_schema = database()", + // no bindvar needs + }, { + in: "select schema()", + expected: "select :__vtdbname as `schema()`", + db: true, + }, { + in: "select found_rows()", + expected: "select :__vtfrows as `found_rows()`", + foundRows: true, + }, { + in: "select @`x y`", + expected: "select :__vtudvx_y as `@``x y``` from dual", + udv: 1, + }, { + in: "select id from t where id = @x and val = @y", + expected: "select id from t where id = :__vtudvx and val = :__vtudvy", + db: false, udv: 2, + }, { + in: "insert into t(id) values(@xyx)", + expected: "insert into t(id) values(:__vtudvxyx)", + db: false, udv: 1, + }, { + in: "select row_count()", + expected: "select :__vtrcount as `row_count()`", + rowCount: true, + }, { + in: "SELECT lower(database())", + expected: "SELECT lower(:__vtdbname) as `lower(database())`", + db: true, + }, { + in: "SELECT * FROM A JOIN B USING (id)", + expected: "SELECT * FROM A JOIN B ON A.id = B.id", + }, { + in: "SELECT * FROM A JOIN B USING (id1,id2,id3)", + expected: "SELECT * FROM A JOIN B ON A.id1 = B.id1 AND A.id2 = B.id2 AND A.id3 = B.id3", + }} for _, tc := range tests { in.Run(tc.in, func(t *testing.T) { diff --git a/go/vt/vtgate/planbuilder/builder.go b/go/vt/vtgate/planbuilder/builder.go index 5f5dabc26da..8a972ba41c6 100644 --- a/go/vt/vtgate/planbuilder/builder.go +++ b/go/vt/vtgate/planbuilder/builder.go @@ -19,6 +19,8 @@ package planbuilder import ( "errors" + querypb "vitess.io/vitess/go/vt/proto/query" + "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/key" @@ -269,7 +271,7 @@ func Build(query string, vschema ContextVSchema) (*engine.Plan, error) { if err != nil { return nil, err } - result, err := sqlparser.RewriteAST(stmt) + result, err := sqlparser.PrepareAST(stmt, map[string]*querypb.BindVariable{}, "", false) if err != nil { return nil, err } diff --git a/go/vt/vtgate/planbuilder/testdata/from_cases.txt b/go/vt/vtgate/planbuilder/testdata/from_cases.txt index bdf90d1151b..e923ee3da67 100644 --- a/go/vt/vtgate/planbuilder/testdata/from_cases.txt +++ b/go/vt/vtgate/planbuilder/testdata/from_cases.txt @@ -2080,6 +2080,43 @@ } } +# join with USING construct +"select user.id from user join user_extra using(id)" +{ + "QueryType": "SELECT", + "Original": "select user.id from user join user_extra using(id)", + "Instructions": { + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "-1", + "TableName": "user_user_extra", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select user.id from user where 1 != 1", + "Query": "select user.id from user", + "Table": "user" + }, + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select 1 from user_extra where 1 != 1", + "Query": "select 1 from user_extra where user_extra.id = :user_id", + "Table": "user_extra" + } + ] + } +} + # verify ',' vs JOIN precedence "select u1.a from unsharded u1, unsharded u2 join unsharded u3 on u1.a = u2.a" "symbol u1.a not found" diff --git a/go/vt/vtgate/planbuilder/testdata/unsupported_cases.txt b/go/vt/vtgate/planbuilder/testdata/unsupported_cases.txt index 7dd4584e1f5..342c7a2f681 100644 --- a/go/vt/vtgate/planbuilder/testdata/unsupported_cases.txt +++ b/go/vt/vtgate/planbuilder/testdata/unsupported_cases.txt @@ -54,10 +54,6 @@ "select * from user natural right join user_extra" "unsupported: natural right join" -# join with USING construct -"select * from user join user_extra using(id)" -"unsupported: join with USING(column_list) clause" - # left join with expressions "select user.id, user_extra.col+1 from user left join user_extra on user.col = user_extra.col" "unsupported: cross-shard left join and column expressions" From 6b51ce91fabf0bffdf8af571e57b1680604c5315 Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Thu, 3 Sep 2020 07:31:38 +0200 Subject: [PATCH 2/8] don't rewrite JOIN USING when there are * in the SELECT Signed-off-by: Andres Taylor --- go/vt/sqlparser/expression_rewriting.go | 10 +++++++++- go/vt/sqlparser/expression_rewriting_test.go | 10 +++++++--- .../vtgate/planbuilder/testdata/unsupported_cases.txt | 4 ++++ 3 files changed, 20 insertions(+), 4 deletions(-) diff --git a/go/vt/sqlparser/expression_rewriting.go b/go/vt/sqlparser/expression_rewriting.go index 07895faf3e9..a7c0efef28f 100644 --- a/go/vt/sqlparser/expression_rewriting.go +++ b/go/vt/sqlparser/expression_rewriting.go @@ -104,6 +104,9 @@ type expressionRewriter struct { bindVars map[string]struct{} shouldRewriteDatabaseFunc bool err error + + // we need to know this to make a decision if we can safely rewrite JOIN USING => JOIN ON + hasStarInSelect bool } func newExpressionRewriter() *expressionRewriter { @@ -132,6 +135,11 @@ func (er *expressionRewriter) goingDown(cursor *Cursor) bool { // select last_insert_id() -> select :__lastInsertId as `last_insert_id()` case *Select: for _, col := range node.SelectExprs { + _, hasStar := col.(*StarExpr) + if hasStar { + er.hasStarInSelect = true + } + aliasedExpr, ok := col.(*AliasedExpr) if ok && aliasedExpr.As.IsEmpty() { buf := NewTrackedBuffer(nil) @@ -163,7 +171,7 @@ func (er *expressionRewriter) goingDown(cursor *Cursor) bool { } case JoinCondition: - if node.Using != nil { + if node.Using != nil && !er.hasStarInSelect { joinTableExpr, ok := cursor.Parent().(*JoinTableExpr) if !ok { // this is not possible with the current AST diff --git a/go/vt/sqlparser/expression_rewriting_test.go b/go/vt/sqlparser/expression_rewriting_test.go index fe6d90623d0..bf4bf2bfeb4 100644 --- a/go/vt/sqlparser/expression_rewriting_test.go +++ b/go/vt/sqlparser/expression_rewriting_test.go @@ -102,11 +102,15 @@ func TestRewrites(in *testing.T) { expected: "SELECT lower(:__vtdbname) as `lower(database())`", db: true, }, { - in: "SELECT * FROM A JOIN B USING (id)", - expected: "SELECT * FROM A JOIN B ON A.id = B.id", + in: "SELECT a.col, b.col FROM A JOIN B USING (id)", + expected: "SELECT a.col, b.col FROM A JOIN B ON A.id = B.id", }, { + in: "SELECT a.col, b.col FROM A JOIN B USING (id1,id2,id3)", + expected: "SELECT a.col, b.col FROM A JOIN B ON A.id1 = B.id1 AND A.id2 = B.id2 AND A.id3 = B.id3", + }, { + // SELECT * behaves different depending the join type used, so if that has been used, we won't rewrite in: "SELECT * FROM A JOIN B USING (id1,id2,id3)", - expected: "SELECT * FROM A JOIN B ON A.id1 = B.id1 AND A.id2 = B.id2 AND A.id3 = B.id3", + expected: "SELECT * FROM A JOIN B USING (id1,id2,id3)", }} for _, tc := range tests { diff --git a/go/vt/vtgate/planbuilder/testdata/unsupported_cases.txt b/go/vt/vtgate/planbuilder/testdata/unsupported_cases.txt index 342c7a2f681..445a76958d9 100644 --- a/go/vt/vtgate/planbuilder/testdata/unsupported_cases.txt +++ b/go/vt/vtgate/planbuilder/testdata/unsupported_cases.txt @@ -46,6 +46,10 @@ "select * from user natural join user_extra" "unsupported: natural join" +# join with USING construct +"select * from user join user_extra using(id)" +"unsupported: join with USING(column_list) clause" + # natural left join "select * from user natural left join user_extra" "unsupported: natural left join" From 8330d4e443b911d0d4d13d795b3d16c256909044 Mon Sep 17 00:00:00 2001 From: GuptaManan100 Date: Tue, 8 Dec 2020 13:19:35 +0530 Subject: [PATCH 3/8] Refactored code and added tests Signed-off-by: GuptaManan100 --- go/vt/sqlparser/ast_rewriting.go | 232 ++++++++++++++++ ...ewriting_test.go => ast_rewriting_test.go} | 0 go/vt/sqlparser/expression_rewriting.go | 254 ------------------ go/vt/vtgate/planbuilder/join.go | 2 +- .../testdata/unsupported_cases.txt | 6 +- 5 files changed, 238 insertions(+), 256 deletions(-) rename go/vt/sqlparser/{expression_rewriting_test.go => ast_rewriting_test.go} (100%) delete mode 100644 go/vt/sqlparser/expression_rewriting.go diff --git a/go/vt/sqlparser/ast_rewriting.go b/go/vt/sqlparser/ast_rewriting.go index ef8a2580273..f78bf8bcb32 100644 --- a/go/vt/sqlparser/ast_rewriting.go +++ b/go/vt/sqlparser/ast_rewriting.go @@ -20,6 +20,10 @@ import ( querypb "vitess.io/vitess/go/vt/proto/query" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/vterrors" + + "strings" + + "vitess.io/vitess/go/vt/sysvars" ) // RewriteASTResult contains the rewritten ast and meta information about it @@ -55,3 +59,231 @@ func RewriteAST(in Statement) (*RewriteASTResult, error) { } return r, nil } + +func shouldRewriteDatabaseFunc(in Statement) bool { + selct, ok := in.(*Select) + if !ok { + return false + } + if len(selct.From) != 1 { + return false + } + aliasedTable, ok := selct.From[0].(*AliasedTableExpr) + if !ok { + return false + } + tableName, ok := aliasedTable.Expr.(TableName) + if !ok { + return false + } + return tableName.Name.String() == "dual" +} + +type expressionRewriter struct { + bindVars *BindVarNeeds + shouldRewriteDatabaseFunc bool + err error + + // we need to know this to make a decision if we can safely rewrite JOIN USING => JOIN ON + hasStarInSelect bool +} + +func newExpressionRewriter() *expressionRewriter { + return &expressionRewriter{bindVars: &BindVarNeeds{}} +} + +const ( + //LastInsertIDName is a reserved bind var name for last_insert_id() + LastInsertIDName = "__lastInsertId" + + //DBVarName is a reserved bind var name for database() + DBVarName = "__vtdbname" + + //FoundRowsName is a reserved bind var name for found_rows() + FoundRowsName = "__vtfrows" + + //RowCountName is a reserved bind var name for row_count() + RowCountName = "__vtrcount" + + //UserDefinedVariableName is what we prepend bind var names for user defined variables + UserDefinedVariableName = "__vtudv" +) + +func (er *expressionRewriter) rewriteAliasedExpr(node *AliasedExpr) (*BindVarNeeds, error) { + inner := newExpressionRewriter() + inner.shouldRewriteDatabaseFunc = er.shouldRewriteDatabaseFunc + tmp := Rewrite(node.Expr, inner.rewrite, nil) + newExpr, ok := tmp.(Expr) + if !ok { + return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "failed to rewrite AST. function expected to return Expr returned a %s", String(tmp)) + } + node.Expr = newExpr + return inner.bindVars, nil +} + +func (er *expressionRewriter) rewrite(cursor *Cursor) bool { + switch node := cursor.Node().(type) { + // select last_insert_id() -> select :__lastInsertId as `last_insert_id()` + case *Select: + for _, col := range node.SelectExprs { + _, hasStar := col.(*StarExpr) + if hasStar { + er.hasStarInSelect = true + } + + aliasedExpr, ok := col.(*AliasedExpr) + if ok && aliasedExpr.As.IsEmpty() { + buf := NewTrackedBuffer(nil) + aliasedExpr.Expr.Format(buf) + innerBindVarNeeds, err := er.rewriteAliasedExpr(aliasedExpr) + if err != nil { + er.err = err + return false + } + if innerBindVarNeeds.HasRewrites() { + aliasedExpr.As = NewColIdent(buf.String()) + } + er.bindVars.MergeWith(innerBindVarNeeds) + } + } + case *FuncExpr: + er.funcRewrite(cursor, node) + case *ColName: + switch node.Name.at { + case SingleAt: + er.udvRewrite(cursor, node) + case DoubleAt: + er.sysVarRewrite(cursor, node) + } + case *Subquery: + er.unnestSubQueries(cursor, node) + + case JoinCondition: + if node.Using != nil && !er.hasStarInSelect { + joinTableExpr, ok := cursor.Parent().(*JoinTableExpr) + if !ok { + // this is not possible with the current AST + break + } + leftTable, leftOk := joinTableExpr.LeftExpr.(*AliasedTableExpr) + rightTable, rightOk := joinTableExpr.RightExpr.(*AliasedTableExpr) + if !(leftOk && rightOk) { + // we only deal with simple FROM A JOIN B USING queries at the moment + break + } + lft, err := leftTable.TableName() + if err != nil { + er.err = err + break + } + rgt, err := rightTable.TableName() + if err != nil { + er.err = err + break + } + newCondition := JoinCondition{} + for _, colIdent := range node.Using { + lftCol := NewColNameWithQualifier(colIdent.String(), lft) + rgtCol := NewColNameWithQualifier(colIdent.String(), rgt) + cmp := &ComparisonExpr{ + Operator: EqualOp, + Left: lftCol, + Right: rgtCol, + } + if newCondition.On == nil { + newCondition.On = cmp + } else { + newCondition.On = &AndExpr{Left: newCondition.On, Right: cmp} + } + } + cursor.Replace(newCondition) + } + + } + return true +} + +func (er *expressionRewriter) sysVarRewrite(cursor *Cursor, node *ColName) { + lowered := node.Name.Lowered() + switch lowered { + case sysvars.Autocommit.Name, + sysvars.ClientFoundRows.Name, + sysvars.SkipQueryPlanCache.Name, + sysvars.SQLSelectLimit.Name, + sysvars.TransactionMode.Name, + sysvars.Workload.Name, + sysvars.DDLStrategy.Name, + sysvars.ReadAfterWriteGTID.Name, + sysvars.ReadAfterWriteTimeOut.Name, + sysvars.SessionTrackGTIDs.Name: + cursor.Replace(bindVarExpression("__vt" + lowered)) + er.bindVars.AddSysVar(lowered) + } +} + +func (er *expressionRewriter) udvRewrite(cursor *Cursor, node *ColName) { + udv := strings.ToLower(node.Name.CompliantName()) + cursor.Replace(bindVarExpression(UserDefinedVariableName + udv)) + er.bindVars.AddUserDefVar(udv) +} + +var funcRewrites = map[string]string{ + "last_insert_id": LastInsertIDName, + "database": DBVarName, + "schema": DBVarName, + "found_rows": FoundRowsName, + "row_count": RowCountName, +} + +func (er *expressionRewriter) funcRewrite(cursor *Cursor, node *FuncExpr) { + bindVar, found := funcRewrites[node.Name.Lowered()] + if found { + if bindVar == DBVarName && !er.shouldRewriteDatabaseFunc { + return + } + if len(node.Exprs) > 0 { + er.err = vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "Argument to %s() not supported", node.Name.Lowered()) + return + } + cursor.Replace(bindVarExpression(bindVar)) + er.bindVars.AddFuncResult(bindVar) + } +} + +func (er *expressionRewriter) unnestSubQueries(cursor *Cursor, subquery *Subquery) { + sel, isSimpleSelect := subquery.Select.(*Select) + if !isSimpleSelect { + return + } + + if !(len(sel.SelectExprs) != 1 || + len(sel.OrderBy) != 0 || + len(sel.GroupBy) != 0 || + len(sel.From) != 1 || + sel.Where == nil || + sel.Having == nil || + sel.Limit == nil) && sel.Lock == NoLock { + return + } + aliasedTable, ok := sel.From[0].(*AliasedTableExpr) + if !ok { + return + } + table, ok := aliasedTable.Expr.(TableName) + if !ok || table.Name.String() != "dual" { + return + } + expr, ok := sel.SelectExprs[0].(*AliasedExpr) + if !ok { + return + } + er.bindVars.NoteRewrite() + // we need to make sure that the inner expression also gets rewritten, + // so we fire off another rewriter traversal here + rewrittenExpr := Rewrite(expr.Expr, er.rewrite, nil) + cursor.Replace(rewrittenExpr) +} + +func bindVarExpression(name string) Expr { + return NewArgument([]byte(":" + name)) +} diff --git a/go/vt/sqlparser/expression_rewriting_test.go b/go/vt/sqlparser/ast_rewriting_test.go similarity index 100% rename from go/vt/sqlparser/expression_rewriting_test.go rename to go/vt/sqlparser/ast_rewriting_test.go diff --git a/go/vt/sqlparser/expression_rewriting.go b/go/vt/sqlparser/expression_rewriting.go deleted file mode 100644 index a619e13c070..00000000000 --- a/go/vt/sqlparser/expression_rewriting.go +++ /dev/null @@ -1,254 +0,0 @@ -/* -Copyright 2019 The Vitess Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package sqlparser - -import ( - "strings" - - "vitess.io/vitess/go/vt/sysvars" - - "vitess.io/vitess/go/vt/proto/vtrpc" - "vitess.io/vitess/go/vt/vterrors" -) - -func shouldRewriteDatabaseFunc(in Statement) bool { - selct, ok := in.(*Select) - if !ok { - return false - } - if len(selct.From) != 1 { - return false - } - aliasedTable, ok := selct.From[0].(*AliasedTableExpr) - if !ok { - return false - } - tableName, ok := aliasedTable.Expr.(TableName) - if !ok { - return false - } - return tableName.Name.String() == "dual" -} - -type expressionRewriter struct { - bindVars *BindVarNeeds - shouldRewriteDatabaseFunc bool - err error - - // we need to know this to make a decision if we can safely rewrite JOIN USING => JOIN ON - hasStarInSelect bool -} - -func newExpressionRewriter() *expressionRewriter { - return &expressionRewriter{bindVars: &BindVarNeeds{}} -} - -const ( - //LastInsertIDName is a reserved bind var name for last_insert_id() - LastInsertIDName = "__lastInsertId" - - //DBVarName is a reserved bind var name for database() - DBVarName = "__vtdbname" - - //FoundRowsName is a reserved bind var name for found_rows() - FoundRowsName = "__vtfrows" - - //RowCountName is a reserved bind var name for row_count() - RowCountName = "__vtrcount" - - //UserDefinedVariableName is what we prepend bind var names for user defined variables - UserDefinedVariableName = "__vtudv" -) - -func (er *expressionRewriter) rewriteAliasedExpr(node *AliasedExpr) (*BindVarNeeds, error) { - inner := newExpressionRewriter() - inner.shouldRewriteDatabaseFunc = er.shouldRewriteDatabaseFunc - tmp := Rewrite(node.Expr, inner.rewrite, nil) - newExpr, ok := tmp.(Expr) - if !ok { - return nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "failed to rewrite AST. function expected to return Expr returned a %s", String(tmp)) - } - node.Expr = newExpr - return inner.bindVars, nil -} - -func (er *expressionRewriter) rewrite(cursor *Cursor) bool { - switch node := cursor.Node().(type) { - // select last_insert_id() -> select :__lastInsertId as `last_insert_id()` - case *Select: - for _, col := range node.SelectExprs { - _, hasStar := col.(*StarExpr) - if hasStar { - er.hasStarInSelect = true - } - - aliasedExpr, ok := col.(*AliasedExpr) - if ok && aliasedExpr.As.IsEmpty() { - buf := NewTrackedBuffer(nil) - aliasedExpr.Expr.Format(buf) - innerBindVarNeeds, err := er.rewriteAliasedExpr(aliasedExpr) - if err != nil { - er.err = err - return false - } - if innerBindVarNeeds.HasRewrites() { - aliasedExpr.As = NewColIdent(buf.String()) - } - er.bindVars.MergeWith(innerBindVarNeeds) - } - } - case *FuncExpr: - er.funcRewrite(cursor, node) - case *ColName: - switch node.Name.at { - case SingleAt: - er.udvRewrite(cursor, node) - case DoubleAt: - er.sysVarRewrite(cursor, node) - } - case *Subquery: - er.unnestSubQueries(cursor, node) - - case JoinCondition: - if node.Using != nil && !er.hasStarInSelect { - joinTableExpr, ok := cursor.Parent().(*JoinTableExpr) - if !ok { - // this is not possible with the current AST - break - } - leftTable, leftOk := joinTableExpr.LeftExpr.(*AliasedTableExpr) - rightTable, rightOk := joinTableExpr.RightExpr.(*AliasedTableExpr) - if !(leftOk && rightOk) { - // we only deal with simple FROM A JOIN B USING queries at the moment - break - } - lft, err := leftTable.TableName() - if err != nil { - er.err = err - break - } - rgt, err := rightTable.TableName() - if err != nil { - er.err = err - break - } - newCondition := JoinCondition{} - for _, colIdent := range node.Using { - lftCol := NewColNameWithQualifier(colIdent.String(), lft) - rgtCol := NewColNameWithQualifier(colIdent.String(), rgt) - cmp := &ComparisonExpr{ - Operator: EqualOp, - Left: lftCol, - Right: rgtCol, - } - if newCondition.On == nil { - newCondition.On = cmp - } else { - newCondition.On = &AndExpr{Left: newCondition.On, Right: cmp} - } - } - cursor.Replace(newCondition) - } - - } - return true -} - -func (er *expressionRewriter) sysVarRewrite(cursor *Cursor, node *ColName) { - lowered := node.Name.Lowered() - switch lowered { - case sysvars.Autocommit.Name, - sysvars.ClientFoundRows.Name, - sysvars.SkipQueryPlanCache.Name, - sysvars.SQLSelectLimit.Name, - sysvars.TransactionMode.Name, - sysvars.Workload.Name, - sysvars.DDLStrategy.Name, - sysvars.ReadAfterWriteGTID.Name, - sysvars.ReadAfterWriteTimeOut.Name, - sysvars.SessionTrackGTIDs.Name: - cursor.Replace(bindVarExpression("__vt" + lowered)) - er.bindVars.AddSysVar(lowered) - } -} - -func (er *expressionRewriter) udvRewrite(cursor *Cursor, node *ColName) { - udv := strings.ToLower(node.Name.CompliantName()) - cursor.Replace(bindVarExpression(UserDefinedVariableName + udv)) - er.bindVars.AddUserDefVar(udv) -} - -var funcRewrites = map[string]string{ - "last_insert_id": LastInsertIDName, - "database": DBVarName, - "schema": DBVarName, - "found_rows": FoundRowsName, - "row_count": RowCountName, -} - -func (er *expressionRewriter) funcRewrite(cursor *Cursor, node *FuncExpr) { - bindVar, found := funcRewrites[node.Name.Lowered()] - if found { - if bindVar == DBVarName && !er.shouldRewriteDatabaseFunc { - return - } - if len(node.Exprs) > 0 { - er.err = vterrors.Errorf(vtrpc.Code_UNIMPLEMENTED, "Argument to %s() not supported", node.Name.Lowered()) - return - } - cursor.Replace(bindVarExpression(bindVar)) - er.bindVars.AddFuncResult(bindVar) - } -} - -func (er *expressionRewriter) unnestSubQueries(cursor *Cursor, subquery *Subquery) { - sel, isSimpleSelect := subquery.Select.(*Select) - if !isSimpleSelect { - return - } - - if !(len(sel.SelectExprs) != 1 || - len(sel.OrderBy) != 0 || - len(sel.GroupBy) != 0 || - len(sel.From) != 1 || - sel.Where == nil || - sel.Having == nil || - sel.Limit == nil) && sel.Lock == NoLock { - return - } - aliasedTable, ok := sel.From[0].(*AliasedTableExpr) - if !ok { - return - } - table, ok := aliasedTable.Expr.(TableName) - if !ok || table.Name.String() != "dual" { - return - } - expr, ok := sel.SelectExprs[0].(*AliasedExpr) - if !ok { - return - } - er.bindVars.NoteRewrite() - // we need to make sure that the inner expression also gets rewritten, - // so we fire off another rewriter traversal here - rewrittenExpr := Rewrite(expr.Expr, er.rewrite, nil) - cursor.Replace(rewrittenExpr) -} - -func bindVarExpression(name string) Expr { - return NewArgument([]byte(":" + name)) -} diff --git a/go/vt/vtgate/planbuilder/join.go b/go/vt/vtgate/planbuilder/join.go index 71b9e74c1ee..dc4ffe26732 100644 --- a/go/vt/vtgate/planbuilder/join.go +++ b/go/vt/vtgate/planbuilder/join.go @@ -98,7 +98,7 @@ func newJoin(lpb, rpb *primitiveBuilder, ajoin *sqlparser.JoinTableExpr) error { return err } case ajoin.Condition.Using != nil: - return errors.New("unsupported: join with USING(column_list) clause") + return errors.New("unsupported: join with USING(column_list) clause for complex queries") } } lpb.plan = &join{ diff --git a/go/vt/vtgate/planbuilder/testdata/unsupported_cases.txt b/go/vt/vtgate/planbuilder/testdata/unsupported_cases.txt index 6dc080e78c1..31593d3c99d 100644 --- a/go/vt/vtgate/planbuilder/testdata/unsupported_cases.txt +++ b/go/vt/vtgate/planbuilder/testdata/unsupported_cases.txt @@ -44,7 +44,11 @@ # join with USING construct "select * from user join user_extra using(id)" -"unsupported: join with USING(column_list) clause" +"unsupported: join with USING(column_list) clause for complex queries" + +# join with USING construct with 3 tables +"select user.id from user join user_extra using(id) join music using(id2)" +"unsupported: join with USING(column_list) clause for complex queries" # natural left join "select * from user natural left join user_extra" From 94aa0a70a65034e54324027b22d5dffdf5873d65 Mon Sep 17 00:00:00 2001 From: GuptaManan100 Date: Tue, 8 Dec 2020 13:41:03 +0530 Subject: [PATCH 4/8] join SelectDBA queries on the vtgate level Signed-off-by: GuptaManan100 --- go/vt/vtgate/planbuilder/route.go | 4 +- .../planbuilder/testdata/from_cases.txt | 73 ++++++++++++++++--- 2 files changed, 65 insertions(+), 12 deletions(-) diff --git a/go/vt/vtgate/planbuilder/route.go b/go/vt/vtgate/planbuilder/route.go index f99e4360f4c..e6eab2b8337 100644 --- a/go/vt/vtgate/planbuilder/route.go +++ b/go/vt/vtgate/planbuilder/route.go @@ -377,7 +377,7 @@ func (rb *route) JoinCanMerge(pb *primitiveBuilder, rrb *route, ajoin *sqlparser return true } switch rb.eroute.Opcode { - case engine.SelectUnsharded, engine.SelectDBA: + case engine.SelectUnsharded: return rb.eroute.Opcode == rrb.eroute.Opcode case engine.SelectEqualUnique: // Check if they target the same shard. @@ -386,7 +386,7 @@ func (rb *route) JoinCanMerge(pb *primitiveBuilder, rrb *route, ajoin *sqlparser } case engine.SelectReference: return true - case engine.SelectNext: + case engine.SelectNext, engine.SelectDBA: return false } if ajoin == nil { diff --git a/go/vt/vtgate/planbuilder/testdata/from_cases.txt b/go/vt/vtgate/planbuilder/testdata/from_cases.txt index ca27e5041cc..5d9c12c491d 100644 --- a/go/vt/vtgate/planbuilder/testdata/from_cases.txt +++ b/go/vt/vtgate/planbuilder/testdata/from_cases.txt @@ -296,19 +296,37 @@ } # ',' join information_schema -"select * from information_schema.a, information_schema.b" +"select a.id,b.id from information_schema.a as a, information_schema.b as b" { "QueryType": "SELECT", - "Original": "select * from information_schema.a, information_schema.b", + "Original": "select a.id,b.id from information_schema.a as a, information_schema.b as b", "Instructions": { - "OperatorType": "Route", - "Variant": "SelectDBA", - "Keyspace": { - "Name": "main", - "Sharded": false - }, - "FieldQuery": "select * from information_schema.a, information_schema.b where 1 != 1", - "Query": "select * from information_schema.a, information_schema.b" + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "-1,1", + "TableName": "_", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "SelectDBA", + "Keyspace": { + "Name": "main", + "Sharded": false + }, + "FieldQuery": "select a.id from information_schema.a as a where 1 != 1", + "Query": "select a.id from information_schema.a as a" + }, + { + "OperatorType": "Route", + "Variant": "SelectDBA", + "Keyspace": { + "Name": "main", + "Sharded": false + }, + "FieldQuery": "select b.id from information_schema.b as b where 1 != 1", + "Query": "select b.id from information_schema.b as b" + } + ] } } @@ -2131,3 +2149,38 @@ "Query": "select column_name from information_schema.`columns` where table_schema = schema()" } } + +# information schema join +"select 42 from information_schema.a join information_schema.b" +{ + "QueryType": "SELECT", + "Original": "select 42 from information_schema.a join information_schema.b", + "Instructions": { + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "-1", + "TableName": "_", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "SelectDBA", + "Keyspace": { + "Name": "main", + "Sharded": false + }, + "FieldQuery": "select 42 from information_schema.a where 1 != 1", + "Query": "select 42 from information_schema.a" + }, + { + "OperatorType": "Route", + "Variant": "SelectDBA", + "Keyspace": { + "Name": "main", + "Sharded": false + }, + "FieldQuery": "select 1 from information_schema.b where 1 != 1", + "Query": "select 1 from information_schema.b" + } + ] + } +} From 5bae863d0a5c146954851d6634db67f0d1122331 Mon Sep 17 00:00:00 2001 From: GuptaManan100 Date: Tue, 8 Dec 2020 14:36:10 +0530 Subject: [PATCH 5/8] Fixed test in join merging Signed-off-by: GuptaManan100 --- go/vt/vtgate/planbuilder/route_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go/vt/vtgate/planbuilder/route_test.go b/go/vt/vtgate/planbuilder/route_test.go index ca133f946ba..cf60f09e28f 100644 --- a/go/vt/vtgate/planbuilder/route_test.go +++ b/go/vt/vtgate/planbuilder/route_test.go @@ -50,7 +50,7 @@ func TestJoinCanMerge(t *testing.T) { {false, false, false, false, false, false, false, false, true, false}, {false, false, false, false, false, false, false, false, true, false}, {false, false, false, false, false, false, false, false, true, false}, - {false, false, false, false, false, false, false, true, true, false}, + {false, false, false, false, false, false, false, false, true, false}, {true, true, true, true, true, true, true, true, true, true}, {false, false, false, false, false, false, false, false, true, false}, } From 5c563c418d0034d798ad62382541688d25dbb6c7 Mon Sep 17 00:00:00 2001 From: GuptaManan100 Date: Tue, 8 Dec 2020 14:40:31 +0530 Subject: [PATCH 6/8] cleanup Signed-off-by: GuptaManan100 --- go/vt/vtgate/planbuilder/builder.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go/vt/vtgate/planbuilder/builder.go b/go/vt/vtgate/planbuilder/builder.go index 6646a4873e6..64e50b3c6e3 100644 --- a/go/vt/vtgate/planbuilder/builder.go +++ b/go/vt/vtgate/planbuilder/builder.go @@ -60,7 +60,7 @@ func Build(query string, vschema ContextVSchema) (*engine.Plan, error) { if err != nil { return nil, err } - result, err := sqlparser.PrepareAST(stmt, map[string]*querypb.BindVariable{}, "", false) + result, err := sqlparser.RewriteAST(stmt) if err != nil { return nil, err } From 82e65653e322781ff5223ab64af0bedef7fd135f Mon Sep 17 00:00:00 2001 From: GuptaManan100 Date: Tue, 8 Dec 2020 14:54:11 +0530 Subject: [PATCH 7/8] Added test for rails query Signed-off-by: GuptaManan100 --- .../planbuilder/testdata/from_cases.txt | 37 +++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/go/vt/vtgate/planbuilder/testdata/from_cases.txt b/go/vt/vtgate/planbuilder/testdata/from_cases.txt index e2d957efaa2..b85e65681f4 100644 --- a/go/vt/vtgate/planbuilder/testdata/from_cases.txt +++ b/go/vt/vtgate/planbuilder/testdata/from_cases.txt @@ -2221,3 +2221,40 @@ ] } } + +# rails query +"select fk.referenced_table_name as to_table, fk.referenced_column_name as primary_key, fk.column_name as `column`, fk.constraint_name as name, rc.update_rule as on_update, rc.delete_rule as on_delete from information_schema.referential_constraints as rc join information_schema.key_column_usage as fk using (constraint_schema, constraint_name) where fk.referenced_column_name is not null and fk.table_schema = database() and fk.table_name = ':vtg1' and rc.constraint_schema = database() and rc.table_name = ':vtg1'" +{ + "QueryType": "SELECT", + "Original": "select fk.referenced_table_name as to_table, fk.referenced_column_name as primary_key, fk.column_name as `column`, fk.constraint_name as name, rc.update_rule as on_update, rc.delete_rule as on_delete from information_schema.referential_constraints as rc join information_schema.key_column_usage as fk using (constraint_schema, constraint_name) where fk.referenced_column_name is not null and fk.table_schema = database() and fk.table_name = ':vtg1' and rc.constraint_schema = database() and rc.table_name = ':vtg1'", + "Instructions": { + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "1,2,3,4,-1,-2", + "TableName": "_", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "SelectDBA", + "Keyspace": { + "Name": "main", + "Sharded": false + }, + "FieldQuery": "select rc.update_rule as on_update, rc.delete_rule as on_delete, rc.constraint_schema, rc.constraint_name from information_schema.referential_constraints as rc where 1 != 1", + "Query": "select rc.update_rule as on_update, rc.delete_rule as on_delete, rc.constraint_schema, rc.constraint_name from information_schema.referential_constraints as rc where rc.constraint_schema = database() and rc.table_name = :__vttablename", + "SysTableTableName": "VARBINARY(\":vtg1\")" + }, + { + "OperatorType": "Route", + "Variant": "SelectDBA", + "Keyspace": { + "Name": "main", + "Sharded": false + }, + "FieldQuery": "select fk.referenced_table_name as to_table, fk.referenced_column_name as primary_key, fk.column_name as `column`, fk.constraint_name as `name` from information_schema.key_column_usage as fk where 1 != 1", + "Query": "select fk.referenced_table_name as to_table, fk.referenced_column_name as primary_key, fk.column_name as `column`, fk.constraint_name as `name` from information_schema.key_column_usage as fk where fk.constraint_schema = :rc_constraint_schema and fk.constraint_name = :rc_constraint_name and fk.referenced_column_name is not null and fk.table_schema = database() and fk.table_name = :__vttablename", + "SysTableTableName": "VARBINARY(\":vtg1\")" + } + ] + } +} From 044de244910b5ac0d4129a56719b8cd90fce4193 Mon Sep 17 00:00:00 2001 From: Harshit Gangal Date: Tue, 8 Dec 2020 16:15:52 +0530 Subject: [PATCH 8/8] added e2e test for join with using in information_schema Signed-off-by: Harshit Gangal --- .../vtgate/information_schema_test.go | 11 +++++++++++ go/test/endtoend/vtgate/main_test.go | 19 ++++++++++++++++++- 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/go/test/endtoend/vtgate/information_schema_test.go b/go/test/endtoend/vtgate/information_schema_test.go index ada2844a8e5..4116efaf086 100644 --- a/go/test/endtoend/vtgate/information_schema_test.go +++ b/go/test/endtoend/vtgate/information_schema_test.go @@ -114,3 +114,14 @@ func TestInformationSchemaQueryGetsRoutedToTheRightTableAndKeyspace(t *testing.T result := exec(t, conn, "SELECT * FROM information_schema.tables WHERE table_schema = database() and table_name='t1000'") assert.NotEmpty(t, result.Rows) } + +func TestFKConstraintUsingInformationSchema(t *testing.T) { + defer cluster.PanicHandler(t) + ctx := context.Background() + conn, err := mysql.Connect(ctx, &vtParams) + require.NoError(t, err) + defer conn.Close() + + query := "select fk.referenced_table_name as to_table, fk.referenced_column_name as primary_key, fk.column_name as `column`, fk.constraint_name as name, rc.update_rule as on_update, rc.delete_rule as on_delete from information_schema.referential_constraints as rc join information_schema.key_column_usage as fk using (constraint_schema, constraint_name) where fk.referenced_column_name is not null and fk.table_schema = database() and fk.table_name = 't7_fk' and rc.constraint_schema = database() and rc.table_name = 't7_fk'" + assertMatches(t, conn, query, `[[VARCHAR("t7_xxhash") VARCHAR("uid") VARCHAR("t7_uid") VARCHAR("t7_fk_ibfk_1") VARCHAR("CASCADE") VARCHAR("SET NULL")]]`) +} diff --git a/go/test/endtoend/vtgate/main_test.go b/go/test/endtoend/vtgate/main_test.go index a9a2ec641b9..5f4422984e9 100644 --- a/go/test/endtoend/vtgate/main_test.go +++ b/go/test/endtoend/vtgate/main_test.go @@ -127,7 +127,16 @@ create table t7_xxhash_idx( phone bigint, keyspace_id varbinary(50), primary key(phone, keyspace_id) -) Engine=InnoDB;` +) Engine=InnoDB; + +create table t7_fk( + id bigint, + t7_uid varchar(50), + primary key(id), + CONSTRAINT t7_fk_ibfk_1 foreign key (t7_uid) references t7_xxhash(uid) + on delete set null on update cascade +) Engine=InnoDB; +` VSchema = ` { @@ -353,6 +362,14 @@ create table t7_xxhash_idx( "name": "unicode_loose_xxhash" } ] + }, + "t7_fk": { + "column_vindexes": [ + { + "column": "t7_uid", + "name": "unicode_loose_xxhash" + } + ] } } }`