Skip to content

Commit

Permalink
Merge pull request #6660 from planetscale/join-using
Browse files Browse the repository at this point in the history
Rewrite joins written with the USING construct
  • Loading branch information
systay authored Dec 8, 2020
2 parents c467dc4 + 044de24 commit fc0fdeb
Show file tree
Hide file tree
Showing 12 changed files with 453 additions and 223 deletions.
5 changes: 5 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -684,6 +684,7 @@ golang.org/x/crypto v0.0.0-20191206172530-e9b2fee46413 h1:ULYEB3JvPRE/IfO+9uO7vK
golang.org/x/crypto v0.0.0-20191206172530-e9b2fee46413/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20200220183623-bac4c82f6975 h1:/Tl7pH94bvbAAHBdZJT947M/+gp0+CqQXDtMRC0fseo=
golang.org/x/crypto v0.0.0-20200220183623-bac4c82f6975/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 h1:psW17arqaxU48Z5kZ0CQnkZWQJsqcURM6tKiBApRjXI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190125153040-c74c464bbbf2/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
Expand Down Expand Up @@ -733,6 +734,7 @@ golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLL
golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e h1:3G+cUijn7XD+S4eJFddp53Pv7+slrESplyjG25HgL+k=
golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
golang.org/x/net v0.0.0-20200822124328-c89045814202 h1:VvcQYSHwXgi7W+TpUR6A9g6Up98WAHf3f/ulnJ62IyA=
golang.org/x/net v0.0.0-20201021035429-f5854403a974 h1:IX6qOQeG5uLjB/hjjwjedwfjND0hgjPMMyO1RoIXQNI=
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
Expand All @@ -745,6 +747,7 @@ golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e h1:vcxGaoTs7kV8m5Np9uUNQin4BrLOthgV7252N8V+FwY=
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9 h1:SQFwaSi55rU7vdNs9Yr0Z324VNlrF+0wMqRXT4St8ck=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20170830134202-bb24a47a89ea/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
Expand Down Expand Up @@ -780,12 +783,14 @@ golang.org/x/sys v0.0.0-20200124204421-9fbb57f87de9/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd h1:xhmwyvizuTgC2qz7ZlMluP20uW+C3Rm0FD/WLDX8884=
golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f h1:+Nyd8tzPX9R7BWHguqsrbFdRx3WQ/1ib8I44HXV5yTA=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.0.0-20160726164857-2910a502d2bf/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
Expand Down
11 changes: 11 additions & 0 deletions go/test/endtoend/vtgate/information_schema_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,3 +114,14 @@ func TestInformationSchemaQueryGetsRoutedToTheRightTableAndKeyspace(t *testing.T
result := exec(t, conn, "SELECT * FROM information_schema.tables WHERE table_schema = database() and table_name='t1000'")
assert.NotEmpty(t, result.Rows)
}

func TestFKConstraintUsingInformationSchema(t *testing.T) {
defer cluster.PanicHandler(t)
ctx := context.Background()
conn, err := mysql.Connect(ctx, &vtParams)
require.NoError(t, err)
defer conn.Close()

query := "select fk.referenced_table_name as to_table, fk.referenced_column_name as primary_key, fk.column_name as `column`, fk.constraint_name as name, rc.update_rule as on_update, rc.delete_rule as on_delete from information_schema.referential_constraints as rc join information_schema.key_column_usage as fk using (constraint_schema, constraint_name) where fk.referenced_column_name is not null and fk.table_schema = database() and fk.table_name = 't7_fk' and rc.constraint_schema = database() and rc.table_name = 't7_fk'"
assertMatches(t, conn, query, `[[VARCHAR("t7_xxhash") VARCHAR("uid") VARCHAR("t7_uid") VARCHAR("t7_fk_ibfk_1") VARCHAR("CASCADE") VARCHAR("SET NULL")]]`)
}
19 changes: 18 additions & 1 deletion go/test/endtoend/vtgate/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,16 @@ create table t7_xxhash_idx(
phone bigint,
keyspace_id varbinary(50),
primary key(phone, keyspace_id)
) Engine=InnoDB;`
) Engine=InnoDB;
create table t7_fk(
id bigint,
t7_uid varchar(50),
primary key(id),
CONSTRAINT t7_fk_ibfk_1 foreign key (t7_uid) references t7_xxhash(uid)
on delete set null on update cascade
) Engine=InnoDB;
`

VSchema = `
{
Expand Down Expand Up @@ -353,6 +362,14 @@ create table t7_xxhash_idx(
"name": "unicode_loose_xxhash"
}
]
},
"t7_fk": {
"column_vindexes": [
{
"column": "t7_uid",
"name": "unicode_loose_xxhash"
}
]
}
}
}`
Expand Down
28 changes: 28 additions & 0 deletions go/vt/sqlparser/ast_funcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ import (
"encoding/json"
"strings"

vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"
"vitess.io/vitess/go/vt/vterrors"

"vitess.io/vitess/go/vt/log"

"vitess.io/vitess/go/sqltypes"
Expand Down Expand Up @@ -337,6 +340,20 @@ func (node *AliasedTableExpr) RemoveHints() *AliasedTableExpr {
return &noHints
}

//TableName returns a TableName pointing to this table expr
func (node *AliasedTableExpr) TableName() (TableName, error) {
if !node.As.IsEmpty() {
return TableName{Name: node.As}, nil
}

tableName, ok := node.Expr.(TableName)
if !ok {
return TableName{}, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "BUG: the AST has changed. This should not be possible")
}

return tableName, nil
}

// IsEmpty returns true if TableName is nil or empty.
func (node TableName) IsEmpty() bool {
// If Name is empty, Qualifier is also empty.
Expand Down Expand Up @@ -509,6 +526,17 @@ func NewColName(str string) *ColName {
}
}

// NewColNameWithQualifier makes a new ColName pointing to a specific table
func NewColNameWithQualifier(identifier string, table TableName) *ColName {
return &ColName{
Name: NewColIdent(identifier),
Qualifier: TableName{
Name: NewTableIdent(table.Name.String()),
Qualifier: NewTableIdent(table.Qualifier.String()),
},
}
}

//NewSelect is used to create a select statement
func NewSelect(comments Comments, exprs SelectExprs, selectOptions []string, from TableExprs, where *Where, groupBy GroupBy, having *Where) *Select {
var cache *bool
Expand Down
232 changes: 232 additions & 0 deletions go/vt/sqlparser/ast_rewriting.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ import (
querypb "vitess.io/vitess/go/vt/proto/query"
vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"
"vitess.io/vitess/go/vt/vterrors"

"strings"

"vitess.io/vitess/go/vt/sysvars"
)

// RewriteASTResult contains the rewritten ast and meta information about it
Expand Down Expand Up @@ -55,3 +59,231 @@ func RewriteAST(in Statement) (*RewriteASTResult, error) {
}
return r, nil
}

func shouldRewriteDatabaseFunc(in Statement) bool {
selct, ok := in.(*Select)
if !ok {
return false
}
if len(selct.From) != 1 {
return false
}
aliasedTable, ok := selct.From[0].(*AliasedTableExpr)
if !ok {
return false
}
tableName, ok := aliasedTable.Expr.(TableName)
if !ok {
return false
}
return tableName.Name.String() == "dual"
}

type expressionRewriter struct {
bindVars *BindVarNeeds
shouldRewriteDatabaseFunc bool
err error

// we need to know this to make a decision if we can safely rewrite JOIN USING => JOIN ON
hasStarInSelect bool
}

func newExpressionRewriter() *expressionRewriter {
return &expressionRewriter{bindVars: &BindVarNeeds{}}
}

const (
//LastInsertIDName is a reserved bind var name for last_insert_id()
LastInsertIDName = "__lastInsertId"

//DBVarName is a reserved bind var name for database()
DBVarName = "__vtdbname"

//FoundRowsName is a reserved bind var name for found_rows()
FoundRowsName = "__vtfrows"

//RowCountName is a reserved bind var name for row_count()
RowCountName = "__vtrcount"

//UserDefinedVariableName is what we prepend bind var names for user defined variables
UserDefinedVariableName = "__vtudv"
)

func (er *expressionRewriter) rewriteAliasedExpr(node *AliasedExpr) (*BindVarNeeds, error) {
inner := newExpressionRewriter()
inner.shouldRewriteDatabaseFunc = er.shouldRewriteDatabaseFunc
tmp := Rewrite(node.Expr, inner.rewrite, nil)
newExpr, ok := tmp.(Expr)
if !ok {
return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "failed to rewrite AST. function expected to return Expr returned a %s", String(tmp))
}
node.Expr = newExpr
return inner.bindVars, nil
}

func (er *expressionRewriter) rewrite(cursor *Cursor) bool {
switch node := cursor.Node().(type) {
// select last_insert_id() -> select :__lastInsertId as `last_insert_id()`
case *Select:
for _, col := range node.SelectExprs {
_, hasStar := col.(*StarExpr)
if hasStar {
er.hasStarInSelect = true
}

aliasedExpr, ok := col.(*AliasedExpr)
if ok && aliasedExpr.As.IsEmpty() {
buf := NewTrackedBuffer(nil)
aliasedExpr.Expr.Format(buf)
innerBindVarNeeds, err := er.rewriteAliasedExpr(aliasedExpr)
if err != nil {
er.err = err
return false
}
if innerBindVarNeeds.HasRewrites() {
aliasedExpr.As = NewColIdent(buf.String())
}
er.bindVars.MergeWith(innerBindVarNeeds)
}
}
case *FuncExpr:
er.funcRewrite(cursor, node)
case *ColName:
switch node.Name.at {
case SingleAt:
er.udvRewrite(cursor, node)
case DoubleAt:
er.sysVarRewrite(cursor, node)
}
case *Subquery:
er.unnestSubQueries(cursor, node)

case JoinCondition:
if node.Using != nil && !er.hasStarInSelect {
joinTableExpr, ok := cursor.Parent().(*JoinTableExpr)
if !ok {
// this is not possible with the current AST
break
}
leftTable, leftOk := joinTableExpr.LeftExpr.(*AliasedTableExpr)
rightTable, rightOk := joinTableExpr.RightExpr.(*AliasedTableExpr)
if !(leftOk && rightOk) {
// we only deal with simple FROM A JOIN B USING queries at the moment
break
}
lft, err := leftTable.TableName()
if err != nil {
er.err = err
break
}
rgt, err := rightTable.TableName()
if err != nil {
er.err = err
break
}
newCondition := JoinCondition{}
for _, colIdent := range node.Using {
lftCol := NewColNameWithQualifier(colIdent.String(), lft)
rgtCol := NewColNameWithQualifier(colIdent.String(), rgt)
cmp := &ComparisonExpr{
Operator: EqualOp,
Left: lftCol,
Right: rgtCol,
}
if newCondition.On == nil {
newCondition.On = cmp
} else {
newCondition.On = &AndExpr{Left: newCondition.On, Right: cmp}
}
}
cursor.Replace(newCondition)
}

}
return true
}

func (er *expressionRewriter) sysVarRewrite(cursor *Cursor, node *ColName) {
lowered := node.Name.Lowered()
switch lowered {
case sysvars.Autocommit.Name,
sysvars.ClientFoundRows.Name,
sysvars.SkipQueryPlanCache.Name,
sysvars.SQLSelectLimit.Name,
sysvars.TransactionMode.Name,
sysvars.Workload.Name,
sysvars.DDLStrategy.Name,
sysvars.ReadAfterWriteGTID.Name,
sysvars.ReadAfterWriteTimeOut.Name,
sysvars.SessionTrackGTIDs.Name:
cursor.Replace(bindVarExpression("__vt" + lowered))
er.bindVars.AddSysVar(lowered)
}
}

func (er *expressionRewriter) udvRewrite(cursor *Cursor, node *ColName) {
udv := strings.ToLower(node.Name.CompliantName())
cursor.Replace(bindVarExpression(UserDefinedVariableName + udv))
er.bindVars.AddUserDefVar(udv)
}

var funcRewrites = map[string]string{
"last_insert_id": LastInsertIDName,
"database": DBVarName,
"schema": DBVarName,
"found_rows": FoundRowsName,
"row_count": RowCountName,
}

func (er *expressionRewriter) funcRewrite(cursor *Cursor, node *FuncExpr) {
bindVar, found := funcRewrites[node.Name.Lowered()]
if found {
if bindVar == DBVarName && !er.shouldRewriteDatabaseFunc {
return
}
if len(node.Exprs) > 0 {
er.err = vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "Argument to %s() not supported", node.Name.Lowered())
return
}
cursor.Replace(bindVarExpression(bindVar))
er.bindVars.AddFuncResult(bindVar)
}
}

func (er *expressionRewriter) unnestSubQueries(cursor *Cursor, subquery *Subquery) {
sel, isSimpleSelect := subquery.Select.(*Select)
if !isSimpleSelect {
return
}

if !(len(sel.SelectExprs) != 1 ||
len(sel.OrderBy) != 0 ||
len(sel.GroupBy) != 0 ||
len(sel.From) != 1 ||
sel.Where == nil ||
sel.Having == nil ||
sel.Limit == nil) && sel.Lock == NoLock {
return
}
aliasedTable, ok := sel.From[0].(*AliasedTableExpr)
if !ok {
return
}
table, ok := aliasedTable.Expr.(TableName)
if !ok || table.Name.String() != "dual" {
return
}
expr, ok := sel.SelectExprs[0].(*AliasedExpr)
if !ok {
return
}
er.bindVars.NoteRewrite()
// we need to make sure that the inner expression also gets rewritten,
// so we fire off another rewriter traversal here
rewrittenExpr := Rewrite(expr.Expr, er.rewrite, nil)
cursor.Replace(rewrittenExpr)
}

func bindVarExpression(name string) Expr {
return NewArgument([]byte(":" + name))
}
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,16 @@ func TestRewrites(in *testing.T) {
in: `select * from user where col = @@read_after_write_gtid OR col = @@read_after_write_timeout OR col = @@session_track_gtids`,
expected: "select * from user where col = :__vtread_after_write_gtid or col = :__vtread_after_write_timeout or col = :__vtsession_track_gtids",
rawGTID: true, rawTimeout: true, sessTrackGTID: true,
}, {
in: "SELECT a.col, b.col FROM A JOIN B USING (id)",
expected: "SELECT a.col, b.col FROM A JOIN B ON A.id = B.id",
}, {
in: "SELECT a.col, b.col FROM A JOIN B USING (id1,id2,id3)",
expected: "SELECT a.col, b.col FROM A JOIN B ON A.id1 = B.id1 AND A.id2 = B.id2 AND A.id3 = B.id3",
}, {
// SELECT * behaves different depending the join type used, so if that has been used, we won't rewrite
in: "SELECT * FROM A JOIN B USING (id1,id2,id3)",
expected: "SELECT * FROM A JOIN B USING (id1,id2,id3)",
}}

for _, tc := range tests {
Expand Down
Loading

0 comments on commit fc0fdeb

Please sign in to comment.