diff --git a/go.sum b/go.sum index c907a93e5d9..137f60c3e6e 100644 --- a/go.sum +++ b/go.sum @@ -684,6 +684,7 @@ golang.org/x/crypto v0.0.0-20191206172530-e9b2fee46413 h1:ULYEB3JvPRE/IfO+9uO7vK golang.org/x/crypto v0.0.0-20191206172530-e9b2fee46413/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200220183623-bac4c82f6975 h1:/Tl7pH94bvbAAHBdZJT947M/+gp0+CqQXDtMRC0fseo= golang.org/x/crypto v0.0.0-20200220183623-bac4c82f6975/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 h1:psW17arqaxU48Z5kZ0CQnkZWQJsqcURM6tKiBApRjXI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190125153040-c74c464bbbf2/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= @@ -733,6 +734,7 @@ golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLL golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e h1:3G+cUijn7XD+S4eJFddp53Pv7+slrESplyjG25HgL+k= golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= golang.org/x/net v0.0.0-20200822124328-c89045814202 h1:VvcQYSHwXgi7W+TpUR6A9g6Up98WAHf3f/ulnJ62IyA= +golang.org/x/net v0.0.0-20201021035429-f5854403a974 h1:IX6qOQeG5uLjB/hjjwjedwfjND0hgjPMMyO1RoIXQNI= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -745,6 +747,7 @@ golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e h1:vcxGaoTs7kV8m5Np9uUNQin4BrLOthgV7252N8V+FwY= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9 h1:SQFwaSi55rU7vdNs9Yr0Z324VNlrF+0wMqRXT4St8ck= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20170830134202-bb24a47a89ea/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -780,12 +783,14 @@ golang.org/x/sys v0.0.0-20200124204421-9fbb57f87de9/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd h1:xhmwyvizuTgC2qz7ZlMluP20uW+C3Rm0FD/WLDX8884= golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f h1:+Nyd8tzPX9R7BWHguqsrbFdRx3WQ/1ib8I44HXV5yTA= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.0.0-20160726164857-2910a502d2bf/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= diff --git a/go/test/endtoend/vtgate/information_schema_test.go b/go/test/endtoend/vtgate/information_schema_test.go index ada2844a8e5..4116efaf086 100644 --- a/go/test/endtoend/vtgate/information_schema_test.go +++ b/go/test/endtoend/vtgate/information_schema_test.go @@ -114,3 +114,14 @@ func TestInformationSchemaQueryGetsRoutedToTheRightTableAndKeyspace(t *testing.T result := exec(t, conn, "SELECT * FROM information_schema.tables WHERE table_schema = database() and table_name='t1000'") assert.NotEmpty(t, result.Rows) } + +func TestFKConstraintUsingInformationSchema(t *testing.T) { + defer cluster.PanicHandler(t) + ctx := context.Background() + conn, err := mysql.Connect(ctx, &vtParams) + require.NoError(t, err) + defer conn.Close() + + query := "select fk.referenced_table_name as to_table, fk.referenced_column_name as primary_key, fk.column_name as `column`, fk.constraint_name as name, rc.update_rule as on_update, rc.delete_rule as on_delete from information_schema.referential_constraints as rc join information_schema.key_column_usage as fk using (constraint_schema, constraint_name) where fk.referenced_column_name is not null and fk.table_schema = database() and fk.table_name = 't7_fk' and rc.constraint_schema = database() and rc.table_name = 't7_fk'" + assertMatches(t, conn, query, `[[VARCHAR("t7_xxhash") VARCHAR("uid") VARCHAR("t7_uid") VARCHAR("t7_fk_ibfk_1") VARCHAR("CASCADE") VARCHAR("SET NULL")]]`) +} diff --git a/go/test/endtoend/vtgate/main_test.go b/go/test/endtoend/vtgate/main_test.go index a9a2ec641b9..5f4422984e9 100644 --- a/go/test/endtoend/vtgate/main_test.go +++ b/go/test/endtoend/vtgate/main_test.go @@ -127,7 +127,16 @@ create table t7_xxhash_idx( phone bigint, keyspace_id varbinary(50), primary key(phone, keyspace_id) -) Engine=InnoDB;` +) Engine=InnoDB; + +create table t7_fk( + id bigint, + t7_uid varchar(50), + primary key(id), + CONSTRAINT t7_fk_ibfk_1 foreign key (t7_uid) references t7_xxhash(uid) + on delete set null on update cascade +) Engine=InnoDB; +` VSchema = ` { @@ -353,6 +362,14 @@ create table t7_xxhash_idx( "name": "unicode_loose_xxhash" } ] + }, + "t7_fk": { + "column_vindexes": [ + { + "column": "t7_uid", + "name": "unicode_loose_xxhash" + } + ] } } }` diff --git a/go/vt/sqlparser/ast_funcs.go b/go/vt/sqlparser/ast_funcs.go index 51b7762043e..f54ba48d5b0 100644 --- a/go/vt/sqlparser/ast_funcs.go +++ b/go/vt/sqlparser/ast_funcs.go @@ -21,6 +21,9 @@ import ( "encoding/json" "strings" + vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" + "vitess.io/vitess/go/vt/vterrors" + "vitess.io/vitess/go/vt/log" "vitess.io/vitess/go/sqltypes" @@ -337,6 +340,20 @@ func (node *AliasedTableExpr) RemoveHints() *AliasedTableExpr { return &noHints } +//TableName returns a TableName pointing to this table expr +func (node *AliasedTableExpr) TableName() (TableName, error) { + if !node.As.IsEmpty() { + return TableName{Name: node.As}, nil + } + + tableName, ok := node.Expr.(TableName) + if !ok { + return TableName{}, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "BUG: the AST has changed. This should not be possible") + } + + return tableName, nil +} + // IsEmpty returns true if TableName is nil or empty. func (node TableName) IsEmpty() bool { // If Name is empty, Qualifier is also empty. @@ -509,6 +526,17 @@ func NewColName(str string) *ColName { } } +// NewColNameWithQualifier makes a new ColName pointing to a specific table +func NewColNameWithQualifier(identifier string, table TableName) *ColName { + return &ColName{ + Name: NewColIdent(identifier), + Qualifier: TableName{ + Name: NewTableIdent(table.Name.String()), + Qualifier: NewTableIdent(table.Qualifier.String()), + }, + } +} + //NewSelect is used to create a select statement func NewSelect(comments Comments, exprs SelectExprs, selectOptions []string, from TableExprs, where *Where, groupBy GroupBy, having *Where) *Select { var cache *bool diff --git a/go/vt/sqlparser/ast_rewriting.go b/go/vt/sqlparser/ast_rewriting.go index ef8a2580273..f78bf8bcb32 100644 --- a/go/vt/sqlparser/ast_rewriting.go +++ b/go/vt/sqlparser/ast_rewriting.go @@ -20,6 +20,10 @@ import ( querypb "vitess.io/vitess/go/vt/proto/query" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/vterrors" + + "strings" + + "vitess.io/vitess/go/vt/sysvars" ) // RewriteASTResult contains the rewritten ast and meta information about it @@ -55,3 +59,231 @@ func RewriteAST(in Statement) (*RewriteASTResult, error) { } return r, nil } + +func shouldRewriteDatabaseFunc(in Statement) bool { + selct, ok := in.(*Select) + if !ok { + return false + } + if len(selct.From) != 1 { + return false + } + aliasedTable, ok := selct.From[0].(*AliasedTableExpr) + if !ok { + return false + } + tableName, ok := aliasedTable.Expr.(TableName) + if !ok { + return false + } + return tableName.Name.String() == "dual" +} + +type expressionRewriter struct { + bindVars *BindVarNeeds + shouldRewriteDatabaseFunc bool + err error + + // we need to know this to make a decision if we can safely rewrite JOIN USING => JOIN ON + hasStarInSelect bool +} + +func newExpressionRewriter() *expressionRewriter { + return &expressionRewriter{bindVars: &BindVarNeeds{}} +} + +const ( + //LastInsertIDName is a reserved bind var name for last_insert_id() + LastInsertIDName = "__lastInsertId" + + //DBVarName is a reserved bind var name for database() + DBVarName = "__vtdbname" + + //FoundRowsName is a reserved bind var name for found_rows() + FoundRowsName = "__vtfrows" + + //RowCountName is a reserved bind var name for row_count() + RowCountName = "__vtrcount" + + //UserDefinedVariableName is what we prepend bind var names for user defined variables + UserDefinedVariableName = "__vtudv" +) + +func (er *expressionRewriter) rewriteAliasedExpr(node *AliasedExpr) (*BindVarNeeds, error) { + inner := newExpressionRewriter() + inner.shouldRewriteDatabaseFunc = er.shouldRewriteDatabaseFunc + tmp := Rewrite(node.Expr, inner.rewrite, nil) + newExpr, ok := tmp.(Expr) + if !ok { + return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "failed to rewrite AST. function expected to return Expr returned a %s", String(tmp)) + } + node.Expr = newExpr + return inner.bindVars, nil +} + +func (er *expressionRewriter) rewrite(cursor *Cursor) bool { + switch node := cursor.Node().(type) { + // select last_insert_id() -> select :__lastInsertId as `last_insert_id()` + case *Select: + for _, col := range node.SelectExprs { + _, hasStar := col.(*StarExpr) + if hasStar { + er.hasStarInSelect = true + } + + aliasedExpr, ok := col.(*AliasedExpr) + if ok && aliasedExpr.As.IsEmpty() { + buf := NewTrackedBuffer(nil) + aliasedExpr.Expr.Format(buf) + innerBindVarNeeds, err := er.rewriteAliasedExpr(aliasedExpr) + if err != nil { + er.err = err + return false + } + if innerBindVarNeeds.HasRewrites() { + aliasedExpr.As = NewColIdent(buf.String()) + } + er.bindVars.MergeWith(innerBindVarNeeds) + } + } + case *FuncExpr: + er.funcRewrite(cursor, node) + case *ColName: + switch node.Name.at { + case SingleAt: + er.udvRewrite(cursor, node) + case DoubleAt: + er.sysVarRewrite(cursor, node) + } + case *Subquery: + er.unnestSubQueries(cursor, node) + + case JoinCondition: + if node.Using != nil && !er.hasStarInSelect { + joinTableExpr, ok := cursor.Parent().(*JoinTableExpr) + if !ok { + // this is not possible with the current AST + break + } + leftTable, leftOk := joinTableExpr.LeftExpr.(*AliasedTableExpr) + rightTable, rightOk := joinTableExpr.RightExpr.(*AliasedTableExpr) + if !(leftOk && rightOk) { + // we only deal with simple FROM A JOIN B USING queries at the moment + break + } + lft, err := leftTable.TableName() + if err != nil { + er.err = err + break + } + rgt, err := rightTable.TableName() + if err != nil { + er.err = err + break + } + newCondition := JoinCondition{} + for _, colIdent := range node.Using { + lftCol := NewColNameWithQualifier(colIdent.String(), lft) + rgtCol := NewColNameWithQualifier(colIdent.String(), rgt) + cmp := &ComparisonExpr{ + Operator: EqualOp, + Left: lftCol, + Right: rgtCol, + } + if newCondition.On == nil { + newCondition.On = cmp + } else { + newCondition.On = &AndExpr{Left: newCondition.On, Right: cmp} + } + } + cursor.Replace(newCondition) + } + + } + return true +} + +func (er *expressionRewriter) sysVarRewrite(cursor *Cursor, node *ColName) { + lowered := node.Name.Lowered() + switch lowered { + case sysvars.Autocommit.Name, + sysvars.ClientFoundRows.Name, + sysvars.SkipQueryPlanCache.Name, + sysvars.SQLSelectLimit.Name, + sysvars.TransactionMode.Name, + sysvars.Workload.Name, + sysvars.DDLStrategy.Name, + sysvars.ReadAfterWriteGTID.Name, + sysvars.ReadAfterWriteTimeOut.Name, + sysvars.SessionTrackGTIDs.Name: + cursor.Replace(bindVarExpression("__vt" + lowered)) + er.bindVars.AddSysVar(lowered) + } +} + +func (er *expressionRewriter) udvRewrite(cursor *Cursor, node *ColName) { + udv := strings.ToLower(node.Name.CompliantName()) + cursor.Replace(bindVarExpression(UserDefinedVariableName + udv)) + er.bindVars.AddUserDefVar(udv) +} + +var funcRewrites = map[string]string{ + "last_insert_id": LastInsertIDName, + "database": DBVarName, + "schema": DBVarName, + "found_rows": FoundRowsName, + "row_count": RowCountName, +} + +func (er *expressionRewriter) funcRewrite(cursor *Cursor, node *FuncExpr) { + bindVar, found := funcRewrites[node.Name.Lowered()] + if found { + if bindVar == DBVarName && !er.shouldRewriteDatabaseFunc { + return + } + if len(node.Exprs) > 0 { + er.err = vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "Argument to %s() not supported", node.Name.Lowered()) + return + } + cursor.Replace(bindVarExpression(bindVar)) + er.bindVars.AddFuncResult(bindVar) + } +} + +func (er *expressionRewriter) unnestSubQueries(cursor *Cursor, subquery *Subquery) { + sel, isSimpleSelect := subquery.Select.(*Select) + if !isSimpleSelect { + return + } + + if !(len(sel.SelectExprs) != 1 || + len(sel.OrderBy) != 0 || + len(sel.GroupBy) != 0 || + len(sel.From) != 1 || + sel.Where == nil || + sel.Having == nil || + sel.Limit == nil) && sel.Lock == NoLock { + return + } + aliasedTable, ok := sel.From[0].(*AliasedTableExpr) + if !ok { + return + } + table, ok := aliasedTable.Expr.(TableName) + if !ok || table.Name.String() != "dual" { + return + } + expr, ok := sel.SelectExprs[0].(*AliasedExpr) + if !ok { + return + } + er.bindVars.NoteRewrite() + // we need to make sure that the inner expression also gets rewritten, + // so we fire off another rewriter traversal here + rewrittenExpr := Rewrite(expr.Expr, er.rewrite, nil) + cursor.Replace(rewrittenExpr) +} + +func bindVarExpression(name string) Expr { + return NewArgument([]byte(":" + name)) +} diff --git a/go/vt/sqlparser/expression_rewriting_test.go b/go/vt/sqlparser/ast_rewriting_test.go similarity index 93% rename from go/vt/sqlparser/expression_rewriting_test.go rename to go/vt/sqlparser/ast_rewriting_test.go index 25a3a0abbc2..a40e2d7435a 100644 --- a/go/vt/sqlparser/expression_rewriting_test.go +++ b/go/vt/sqlparser/ast_rewriting_test.go @@ -154,6 +154,16 @@ func TestRewrites(in *testing.T) { in: `select * from user where col = @@read_after_write_gtid OR col = @@read_after_write_timeout OR col = @@session_track_gtids`, expected: "select * from user where col = :__vtread_after_write_gtid or col = :__vtread_after_write_timeout or col = :__vtsession_track_gtids", rawGTID: true, rawTimeout: true, sessTrackGTID: true, + }, { + in: "SELECT a.col, b.col FROM A JOIN B USING (id)", + expected: "SELECT a.col, b.col FROM A JOIN B ON A.id = B.id", + }, { + in: "SELECT a.col, b.col FROM A JOIN B USING (id1,id2,id3)", + expected: "SELECT a.col, b.col FROM A JOIN B ON A.id1 = B.id1 AND A.id2 = B.id2 AND A.id3 = B.id3", + }, { + // SELECT * behaves different depending the join type used, so if that has been used, we won't rewrite + in: "SELECT * FROM A JOIN B USING (id1,id2,id3)", + expected: "SELECT * FROM A JOIN B USING (id1,id2,id3)", }} for _, tc := range tests { diff --git a/go/vt/sqlparser/expression_rewriting.go b/go/vt/sqlparser/expression_rewriting.go deleted file mode 100644 index 95e6f588adf..00000000000 --- a/go/vt/sqlparser/expression_rewriting.go +++ /dev/null @@ -1,204 +0,0 @@ -/* -Copyright 2019 The Vitess Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package sqlparser - -import ( - "strings" - - "vitess.io/vitess/go/vt/sysvars" - - "vitess.io/vitess/go/vt/proto/vtrpc" - "vitess.io/vitess/go/vt/vterrors" -) - -func shouldRewriteDatabaseFunc(in Statement) bool { - selct, ok := in.(*Select) - if !ok { - return false - } - if len(selct.From) != 1 { - return false - } - aliasedTable, ok := selct.From[0].(*AliasedTableExpr) - if !ok { - return false - } - tableName, ok := aliasedTable.Expr.(TableName) - if !ok { - return false - } - return tableName.Name.String() == "dual" -} - -type expressionRewriter struct { - bindVars *BindVarNeeds - shouldRewriteDatabaseFunc bool - err error -} - -func newExpressionRewriter() *expressionRewriter { - return &expressionRewriter{bindVars: &BindVarNeeds{}} -} - -const ( - //LastInsertIDName is a reserved bind var name for last_insert_id() - LastInsertIDName = "__lastInsertId" - - //DBVarName is a reserved bind var name for database() - DBVarName = "__vtdbname" - - //FoundRowsName is a reserved bind var name for found_rows() - FoundRowsName = "__vtfrows" - - //RowCountName is a reserved bind var name for row_count() - RowCountName = "__vtrcount" - - //UserDefinedVariableName is what we prepend bind var names for user defined variables - UserDefinedVariableName = "__vtudv" -) - -func (er *expressionRewriter) rewriteAliasedExpr(node *AliasedExpr) (*BindVarNeeds, error) { - inner := newExpressionRewriter() - inner.shouldRewriteDatabaseFunc = er.shouldRewriteDatabaseFunc - tmp := Rewrite(node.Expr, inner.rewrite, nil) - newExpr, ok := tmp.(Expr) - if !ok { - return nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "failed to rewrite AST. function expected to return Expr returned a %s", String(tmp)) - } - node.Expr = newExpr - return inner.bindVars, nil -} - -func (er *expressionRewriter) rewrite(cursor *Cursor) bool { - switch node := cursor.Node().(type) { - // select last_insert_id() -> select :__lastInsertId as `last_insert_id()` - case *Select: - for _, col := range node.SelectExprs { - aliasedExpr, ok := col.(*AliasedExpr) - if ok && aliasedExpr.As.IsEmpty() { - buf := NewTrackedBuffer(nil) - aliasedExpr.Expr.Format(buf) - innerBindVarNeeds, err := er.rewriteAliasedExpr(aliasedExpr) - if err != nil { - er.err = err - return false - } - if innerBindVarNeeds.HasRewrites() { - aliasedExpr.As = NewColIdent(buf.String()) - } - er.bindVars.MergeWith(innerBindVarNeeds) - } - } - case *FuncExpr: - er.funcRewrite(cursor, node) - case *ColName: - switch node.Name.at { - case SingleAt: - er.udvRewrite(cursor, node) - case DoubleAt: - er.sysVarRewrite(cursor, node) - } - case *Subquery: - er.unnestSubQueries(cursor, node) - } - return true -} - -func (er *expressionRewriter) sysVarRewrite(cursor *Cursor, node *ColName) { - lowered := node.Name.Lowered() - switch lowered { - case sysvars.Autocommit.Name, - sysvars.ClientFoundRows.Name, - sysvars.SkipQueryPlanCache.Name, - sysvars.SQLSelectLimit.Name, - sysvars.TransactionMode.Name, - sysvars.Workload.Name, - sysvars.DDLStrategy.Name, - sysvars.ReadAfterWriteGTID.Name, - sysvars.ReadAfterWriteTimeOut.Name, - sysvars.SessionTrackGTIDs.Name: - cursor.Replace(bindVarExpression("__vt" + lowered)) - er.bindVars.AddSysVar(lowered) - } -} - -func (er *expressionRewriter) udvRewrite(cursor *Cursor, node *ColName) { - udv := strings.ToLower(node.Name.CompliantName()) - cursor.Replace(bindVarExpression(UserDefinedVariableName + udv)) - er.bindVars.AddUserDefVar(udv) -} - -var funcRewrites = map[string]string{ - "last_insert_id": LastInsertIDName, - "database": DBVarName, - "schema": DBVarName, - "found_rows": FoundRowsName, - "row_count": RowCountName, -} - -func (er *expressionRewriter) funcRewrite(cursor *Cursor, node *FuncExpr) { - bindVar, found := funcRewrites[node.Name.Lowered()] - if found { - if bindVar == DBVarName && !er.shouldRewriteDatabaseFunc { - return - } - if len(node.Exprs) > 0 { - er.err = vterrors.Errorf(vtrpc.Code_UNIMPLEMENTED, "Argument to %s() not supported", node.Name.Lowered()) - return - } - cursor.Replace(bindVarExpression(bindVar)) - er.bindVars.AddFuncResult(bindVar) - } -} - -func (er *expressionRewriter) unnestSubQueries(cursor *Cursor, subquery *Subquery) { - sel, isSimpleSelect := subquery.Select.(*Select) - if !isSimpleSelect { - return - } - - if !(len(sel.SelectExprs) != 1 || - len(sel.OrderBy) != 0 || - len(sel.GroupBy) != 0 || - len(sel.From) != 1 || - sel.Where == nil || - sel.Having == nil || - sel.Limit == nil) && sel.Lock == NoLock { - return - } - aliasedTable, ok := sel.From[0].(*AliasedTableExpr) - if !ok { - return - } - table, ok := aliasedTable.Expr.(TableName) - if !ok || table.Name.String() != "dual" { - return - } - expr, ok := sel.SelectExprs[0].(*AliasedExpr) - if !ok { - return - } - er.bindVars.NoteRewrite() - // we need to make sure that the inner expression also gets rewritten, - // so we fire off another rewriter traversal here - rewrittenExpr := Rewrite(expr.Expr, er.rewrite, nil) - cursor.Replace(rewrittenExpr) -} - -func bindVarExpression(name string) Expr { - return NewArgument([]byte(":" + name)) -} diff --git a/go/vt/vtgate/planbuilder/join.go b/go/vt/vtgate/planbuilder/join.go index 71b9e74c1ee..dc4ffe26732 100644 --- a/go/vt/vtgate/planbuilder/join.go +++ b/go/vt/vtgate/planbuilder/join.go @@ -98,7 +98,7 @@ func newJoin(lpb, rpb *primitiveBuilder, ajoin *sqlparser.JoinTableExpr) error { return err } case ajoin.Condition.Using != nil: - return errors.New("unsupported: join with USING(column_list) clause") + return errors.New("unsupported: join with USING(column_list) clause for complex queries") } } lpb.plan = &join{ diff --git a/go/vt/vtgate/planbuilder/route.go b/go/vt/vtgate/planbuilder/route.go index f99e4360f4c..e6eab2b8337 100644 --- a/go/vt/vtgate/planbuilder/route.go +++ b/go/vt/vtgate/planbuilder/route.go @@ -377,7 +377,7 @@ func (rb *route) JoinCanMerge(pb *primitiveBuilder, rrb *route, ajoin *sqlparser return true } switch rb.eroute.Opcode { - case engine.SelectUnsharded, engine.SelectDBA: + case engine.SelectUnsharded: return rb.eroute.Opcode == rrb.eroute.Opcode case engine.SelectEqualUnique: // Check if they target the same shard. @@ -386,7 +386,7 @@ func (rb *route) JoinCanMerge(pb *primitiveBuilder, rrb *route, ajoin *sqlparser } case engine.SelectReference: return true - case engine.SelectNext: + case engine.SelectNext, engine.SelectDBA: return false } if ajoin == nil { diff --git a/go/vt/vtgate/planbuilder/route_test.go b/go/vt/vtgate/planbuilder/route_test.go index ca133f946ba..cf60f09e28f 100644 --- a/go/vt/vtgate/planbuilder/route_test.go +++ b/go/vt/vtgate/planbuilder/route_test.go @@ -50,7 +50,7 @@ func TestJoinCanMerge(t *testing.T) { {false, false, false, false, false, false, false, false, true, false}, {false, false, false, false, false, false, false, false, true, false}, {false, false, false, false, false, false, false, false, true, false}, - {false, false, false, false, false, false, false, true, true, false}, + {false, false, false, false, false, false, false, false, true, false}, {true, true, true, true, true, true, true, true, true, true}, {false, false, false, false, false, false, false, false, true, false}, } diff --git a/go/vt/vtgate/planbuilder/testdata/from_cases.txt b/go/vt/vtgate/planbuilder/testdata/from_cases.txt index ca27e5041cc..b85e65681f4 100644 --- a/go/vt/vtgate/planbuilder/testdata/from_cases.txt +++ b/go/vt/vtgate/planbuilder/testdata/from_cases.txt @@ -296,19 +296,37 @@ } # ',' join information_schema -"select * from information_schema.a, information_schema.b" +"select a.id,b.id from information_schema.a as a, information_schema.b as b" { "QueryType": "SELECT", - "Original": "select * from information_schema.a, information_schema.b", + "Original": "select a.id,b.id from information_schema.a as a, information_schema.b as b", "Instructions": { - "OperatorType": "Route", - "Variant": "SelectDBA", - "Keyspace": { - "Name": "main", - "Sharded": false - }, - "FieldQuery": "select * from information_schema.a, information_schema.b where 1 != 1", - "Query": "select * from information_schema.a, information_schema.b" + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "-1,1", + "TableName": "_", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "SelectDBA", + "Keyspace": { + "Name": "main", + "Sharded": false + }, + "FieldQuery": "select a.id from information_schema.a as a where 1 != 1", + "Query": "select a.id from information_schema.a as a" + }, + { + "OperatorType": "Route", + "Variant": "SelectDBA", + "Keyspace": { + "Name": "main", + "Sharded": false + }, + "FieldQuery": "select b.id from information_schema.b as b where 1 != 1", + "Query": "select b.id from information_schema.b as b" + } + ] } } @@ -2079,6 +2097,43 @@ } } +# join with USING construct +"select user.id from user join user_extra using(id)" +{ + "QueryType": "SELECT", + "Original": "select user.id from user join user_extra using(id)", + "Instructions": { + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "-1", + "TableName": "user_user_extra", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select user.id from user where 1 != 1", + "Query": "select user.id from user", + "Table": "user" + }, + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select 1 from user_extra where 1 != 1", + "Query": "select 1 from user_extra where user_extra.id = :user_id", + "Table": "user_extra" + } + ] + } +} + # verify ',' vs JOIN precedence "select u1.a from unsharded u1, unsharded u2 join unsharded u3 on u1.a = u2.a" "symbol u1.a not found" @@ -2131,3 +2186,75 @@ "Query": "select column_name from information_schema.`columns` where table_schema = schema()" } } + +# information schema join +"select 42 from information_schema.a join information_schema.b" +{ + "QueryType": "SELECT", + "Original": "select 42 from information_schema.a join information_schema.b", + "Instructions": { + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "-1", + "TableName": "_", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "SelectDBA", + "Keyspace": { + "Name": "main", + "Sharded": false + }, + "FieldQuery": "select 42 from information_schema.a where 1 != 1", + "Query": "select 42 from information_schema.a" + }, + { + "OperatorType": "Route", + "Variant": "SelectDBA", + "Keyspace": { + "Name": "main", + "Sharded": false + }, + "FieldQuery": "select 1 from information_schema.b where 1 != 1", + "Query": "select 1 from information_schema.b" + } + ] + } +} + +# rails query +"select fk.referenced_table_name as to_table, fk.referenced_column_name as primary_key, fk.column_name as `column`, fk.constraint_name as name, rc.update_rule as on_update, rc.delete_rule as on_delete from information_schema.referential_constraints as rc join information_schema.key_column_usage as fk using (constraint_schema, constraint_name) where fk.referenced_column_name is not null and fk.table_schema = database() and fk.table_name = ':vtg1' and rc.constraint_schema = database() and rc.table_name = ':vtg1'" +{ + "QueryType": "SELECT", + "Original": "select fk.referenced_table_name as to_table, fk.referenced_column_name as primary_key, fk.column_name as `column`, fk.constraint_name as name, rc.update_rule as on_update, rc.delete_rule as on_delete from information_schema.referential_constraints as rc join information_schema.key_column_usage as fk using (constraint_schema, constraint_name) where fk.referenced_column_name is not null and fk.table_schema = database() and fk.table_name = ':vtg1' and rc.constraint_schema = database() and rc.table_name = ':vtg1'", + "Instructions": { + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "1,2,3,4,-1,-2", + "TableName": "_", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "SelectDBA", + "Keyspace": { + "Name": "main", + "Sharded": false + }, + "FieldQuery": "select rc.update_rule as on_update, rc.delete_rule as on_delete, rc.constraint_schema, rc.constraint_name from information_schema.referential_constraints as rc where 1 != 1", + "Query": "select rc.update_rule as on_update, rc.delete_rule as on_delete, rc.constraint_schema, rc.constraint_name from information_schema.referential_constraints as rc where rc.constraint_schema = database() and rc.table_name = :__vttablename", + "SysTableTableName": "VARBINARY(\":vtg1\")" + }, + { + "OperatorType": "Route", + "Variant": "SelectDBA", + "Keyspace": { + "Name": "main", + "Sharded": false + }, + "FieldQuery": "select fk.referenced_table_name as to_table, fk.referenced_column_name as primary_key, fk.column_name as `column`, fk.constraint_name as `name` from information_schema.key_column_usage as fk where 1 != 1", + "Query": "select fk.referenced_table_name as to_table, fk.referenced_column_name as primary_key, fk.column_name as `column`, fk.constraint_name as `name` from information_schema.key_column_usage as fk where fk.constraint_schema = :rc_constraint_schema and fk.constraint_name = :rc_constraint_name and fk.referenced_column_name is not null and fk.table_schema = database() and fk.table_name = :__vttablename", + "SysTableTableName": "VARBINARY(\":vtg1\")" + } + ] + } +} diff --git a/go/vt/vtgate/planbuilder/testdata/unsupported_cases.txt b/go/vt/vtgate/planbuilder/testdata/unsupported_cases.txt index 966f6132e80..31593d3c99d 100644 --- a/go/vt/vtgate/planbuilder/testdata/unsupported_cases.txt +++ b/go/vt/vtgate/planbuilder/testdata/unsupported_cases.txt @@ -42,6 +42,14 @@ "select * from user natural join user_extra" "unsupported: natural join" +# join with USING construct +"select * from user join user_extra using(id)" +"unsupported: join with USING(column_list) clause for complex queries" + +# join with USING construct with 3 tables +"select user.id from user join user_extra using(id) join music using(id2)" +"unsupported: join with USING(column_list) clause for complex queries" + # natural left join "select * from user natural left join user_extra" "unsupported: natural left join" @@ -50,10 +58,6 @@ "select * from user natural right join user_extra" "unsupported: natural right join" -# join with USING construct -"select * from user join user_extra using(id)" -"unsupported: join with USING(column_list) clause" - # left join with expressions "select user.id, user_extra.col+1 from user left join user_extra on user.col = user_extra.col" "unsupported: cross-shard left join and column expressions"