From 4f6c4f8d8c3581a75dfb43713dc3ff2914c27524 Mon Sep 17 00:00:00 2001 From: Miguel Molina Date: Wed, 12 Jun 2019 16:57:28 +0200 Subject: [PATCH] sql/analyzer: refactor resolve_natural_joins rule Signed-off-by: Miguel Molina --- go.mod | 1 - go.sum | 3 + sql/analyzer/resolve_natural_joins.go | 321 ++++++++------------------ sql/plan/exchange.go | 10 +- 4 files changed, 100 insertions(+), 235 deletions(-) diff --git a/go.mod b/go.mod index fa3a2941e..d409bd6b2 100644 --- a/go.mod +++ b/go.mod @@ -23,7 +23,6 @@ require ( github.com/stretchr/testify v1.2.2 go.etcd.io/bbolt v1.3.2 golang.org/x/net v0.0.0-20190227022144-312bce6e941f // indirect - google.golang.org/genproto v0.0.0-20180831171423-11092d34479b // indirect google.golang.org/grpc v1.19.0 // indirect gopkg.in/src-d/go-errors.v1 v1.0.0 gopkg.in/yaml.v2 v2.2.2 diff --git a/go.sum b/go.sum index 2cfe4bd07..5632ba0b5 100644 --- a/go.sum +++ b/go.sum @@ -40,6 +40,8 @@ github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b h1:VKtxabqXZkF25pY9ekf github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.0 h1:kbxbvI4Un1LUWKxufD+BiE6AEExYYgkQLQmLFqA1LFk= +github.com/golang/protobuf v1.3.0/go.mod h1:Qd/q+1AKNOZr9uGQzbzCmRO6sUih6GTPZv6a1/R87v0= github.com/golang/protobuf v1.3.1 h1:YF8+flBXS5eO826T4nzqPrxfhQThhXl0YzfuUPu4SBg= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c h1:964Od4U6p2jUkFxvCydnIczKteheJEzHRToSGK3Bnlw= @@ -137,6 +139,7 @@ golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9 h1:mKdxBk7AujPs8kU4m80U72 golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181023162649-9b4f9f5ad519/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190227022144-312bce6e941f h1:tbtX/qtlxzhZjgQue/7u7ygFwDEckd+DmS5+t8FgeKE= golang.org/x/net v0.0.0-20190227022144-312bce6e941f/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= diff --git a/sql/analyzer/resolve_natural_joins.go b/sql/analyzer/resolve_natural_joins.go index 97f2561b3..0ada6eb36 100644 --- a/sql/analyzer/resolve_natural_joins.go +++ b/sql/analyzer/resolve_natural_joins.go @@ -1,266 +1,137 @@ package analyzer import ( - "reflect" + "strings" "github.com/src-d/go-mysql-server/sql" "github.com/src-d/go-mysql-server/sql/expression" "github.com/src-d/go-mysql-server/sql/plan" ) -type transformedJoin struct { - node sql.Node - condCols map[string]*transformedSource -} - -type transformedSource struct { - correct string - wrong []string -} - func resolveNaturalJoins(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { span, _ := ctx.Span("resolve_natural_joins") defer span.Finish() - if n.Resolved() { - return n, nil - } - - var transformed []*transformedJoin - var aliasTables = map[string][]string{} - var colsToUnresolve = map[string]*transformedSource{} - a.Log("resolving natural joins, node of type %T", n) - node, err := n.TransformUp(func(n sql.Node) (sql.Node, error) { - a.Log("transforming node of type: %T", n) + var replacements = make(map[tableCol]tableCol) + var tableAliases = make(map[string]string) - if alias, ok := n.(*plan.TableAlias); ok { - table := alias.Child.(*plan.ResolvedTable).Name() - aliasTables[alias.Name()] = append(aliasTables[alias.Name()], table) + return n.TransformUp(func(n sql.Node) (sql.Node, error) { + switch n := n.(type) { + case *plan.TableAlias: + alias := n.Name() + table := n.Child.(*plan.ResolvedTable).Name() + tableAliases[strings.ToLower(alias)] = table return n, nil - } - - if n.Resolved() { + case *plan.NaturalJoin: + return resolveNaturalJoin(n, replacements) + case sql.Expressioner: + return replaceExpressions(n, replacements, tableAliases) + default: return n, nil } + }) +} - join, ok := n.(*plan.NaturalJoin) - if !ok { - return n, nil - } - - // we need both leaves resolved before resolving this one - if !join.Left.Resolved() || !join.Right.Resolved() { - return n, nil - } - - leftSchema, rightSchema := join.Left.Schema(), join.Right.Schema() - - var conditions, common, left, right []sql.Expression - var seen = make(map[string]struct{}) - - for i, lcol := range leftSchema { - var found bool - leftCol := expression.NewGetFieldWithTable( - i, - lcol.Type, - lcol.Source, - lcol.Name, - lcol.Nullable, - ) - - for j, rcol := range rightSchema { - if lcol.Name == rcol.Name { - common = append(common, leftCol) - - conditions = append( - conditions, - expression.NewEquals( - leftCol, - expression.NewGetFieldWithTable( - len(leftSchema)+j, - rcol.Type, - rcol.Source, - rcol.Name, - rcol.Nullable, - ), - ), - ) - - found = true - seen[lcol.Name] = struct{}{} - if source, ok := colsToUnresolve[lcol.Name]; ok { - source.correct = lcol.Source - source.wrong = append(source.wrong, rcol.Source) - } else { - colsToUnresolve[lcol.Name] = &transformedSource{ - correct: lcol.Source, - wrong: []string{rcol.Source}, - } - } - - break - } - } +func resolveNaturalJoin( + n *plan.NaturalJoin, + replacements map[tableCol]tableCol, +) (sql.Node, error) { + // Both sides of the natural join need to be resolved in order to resolve + // the natural join itself. + if !n.Left.Resolved() || !n.Right.Resolved() { + return n, nil + } - if !found { - left = append(left, leftCol) + leftSchema := n.Left.Schema() + rightSchema := n.Right.Schema() + + var conditions, common, left, right []sql.Expression + for i, lcol := range leftSchema { + leftCol := expression.NewGetFieldWithTable( + i, + lcol.Type, + lcol.Source, + lcol.Name, + lcol.Nullable, + ) + if idx, rcol := findCol(rightSchema, lcol.Name); rcol != nil { + common = append(common, leftCol) + replacements[tableCol{strings.ToLower(rcol.Source), strings.ToLower(rcol.Name)}] = tableCol{ + strings.ToLower(lcol.Source), strings.ToLower(lcol.Name), } - } - if len(conditions) == 0 { - return plan.NewCrossJoin(join.Left, join.Right), nil - } - - for i, col := range rightSchema { - if _, ok := seen[col.Name]; !ok { - right = append( - right, + conditions = append( + conditions, + expression.NewEquals( + leftCol, expression.NewGetFieldWithTable( - len(leftSchema)+i, - col.Type, - col.Source, - col.Name, - col.Nullable, + len(leftSchema)+idx, + rcol.Type, + rcol.Source, + rcol.Name, + rcol.Nullable, ), - ) - } - } - - projections := append(append(common, left...), right...) - - tj := &transformedJoin{ - node: plan.NewProject( - projections, - plan.NewInnerJoin( - join.Left, - join.Right, - expression.JoinAnd(conditions...), ), - ), - condCols: colsToUnresolve, + ) + } else { + left = append(left, leftCol) } - - transformed = append(transformed, tj) - - return tj.node, nil - }) - - if err != nil || len(transformed) == 0 { - return node, err } - var transformedSeen bool - return node.TransformUp(func(node sql.Node) (sql.Node, error) { - if ok, _ := isTransformedNode(node, transformed); ok { - transformedSeen = true - return node, nil - } - - if !transformedSeen { - return node, nil - } - - expressioner, ok := node.(sql.Expressioner) - if !ok { - return node, nil - } - - return expressioner.TransformExpressions(func(e sql.Expression) (sql.Expression, error) { - var col, table string - switch e := e.(type) { - case *expression.GetField: - col, table = e.Name(), e.Table() - case *expression.UnresolvedColumn: - col, table = e.Name(), e.Table() - default: - return e, nil - } - - sources, ok := colsToUnresolve[col] - if !ok { - return e, nil - } - - if !mustUnresolve(aliasTables, table, sources.wrong) { - return e, nil - } - - return expression.NewUnresolvedQualifiedColumn( - sources.correct, - col, - ), nil - }) - }) -} - -func isTransformedNode(node sql.Node, transformed []*transformedJoin) (is bool, colsToUnresolve map[string]*transformedSource) { - var project *plan.Project - var join *plan.InnerJoin - switch n := node.(type) { - case *plan.Project: - var ok bool - join, ok = n.Child.(*plan.InnerJoin) - if !ok { - return - } - - project = n - case *plan.InnerJoin: - join = n - - default: - return + if len(conditions) == 0 { + return plan.NewCrossJoin(n.Left, n.Right), nil } - for _, t := range transformed { - tproject, ok := t.node.(*plan.Project) - if !ok { - return - } - - tjoin, ok := tproject.Child.(*plan.InnerJoin) - if !ok { - return - } - - if project != nil && !reflect.DeepEqual(project.Projections, tproject.Projections) { - continue - } - - if reflect.DeepEqual(join.Cond, tjoin.Cond) { - is = true - colsToUnresolve = t.condCols + for i, col := range rightSchema { + source := strings.ToLower(col.Source) + name := strings.ToLower(col.Name) + if _, ok := replacements[tableCol{source, name}]; !ok { + right = append( + right, + expression.NewGetFieldWithTable( + len(leftSchema)+i, + col.Type, + col.Source, + col.Name, + col.Nullable, + ), + ) } } - return + return plan.NewProject( + append(append(common, left...), right...), + plan.NewInnerJoin(n.Left, n.Right, expression.JoinAnd(conditions...)), + ), nil } -func mustUnresolve(aliasTable map[string][]string, table string, wrongSources []string) bool { - return isIn(table, wrongSources) || isAliasFor(aliasTable, table, wrongSources) -} - -func isIn(s string, l []string) bool { - for _, e := range l { - if s == e { - return true +func findCol(s sql.Schema, name string) (int, *sql.Column) { + for i, c := range s { + if strings.ToLower(c.Name) == strings.ToLower(name) { + return i, c } } - - return false + return -1, nil } -func isAliasFor(aliasTable map[string][]string, table string, wrongSources []string) bool { - tables, ok := aliasTable[table] - if !ok { - return false - } +func replaceExpressions( + n sql.Expressioner, + replacements map[tableCol]tableCol, + tableAliases map[string]string, +) (sql.Node, error) { + return n.TransformExpressions(func(e sql.Expression) (sql.Expression, error) { + switch e := e.(type) { + case *expression.GetField, *expression.UnresolvedColumn: + var tableName = e.(sql.Tableable).Table() + if t, ok := tableAliases[strings.ToLower(tableName)]; ok { + tableName = t + } - for _, t := range tables { - if isIn(t, wrongSources) { - return true + name := e.(sql.Nameable).Name() + if col, ok := replacements[tableCol{strings.ToLower(tableName), strings.ToLower(name)}]; ok { + return expression.NewUnresolvedQualifiedColumn(col.table, col.col), nil + } } - } - - return false + return e, nil + }) } diff --git a/sql/plan/exchange.go b/sql/plan/exchange.go index 18c7ac06f..9c4347369 100644 --- a/sql/plan/exchange.go +++ b/sql/plan/exchange.go @@ -6,8 +6,8 @@ import ( "io" "sync" - errors "gopkg.in/src-d/go-errors.v1" "github.com/src-d/go-mysql-server/sql" + errors "gopkg.in/src-d/go-errors.v1" ) // ErrNoPartitionable is returned when no Partitionable node is found @@ -280,14 +280,6 @@ func (it *exchangeRowIter) Close() (err error) { it.quit = nil } - // TODO(kuba): in my opinion we should close err channel here, - // but becasue we use it in another go routine, I'll leave this block commented. - // - // if it.err != nil { - // close(it.err) - // it.err = nil - // } - if it.partitions != nil { err = it.partitions.Close() }