From f0832ac3000e85b86c64d9212f42e697cbec6104 Mon Sep 17 00:00:00 2001 From: Miguel Molina Date: Fri, 17 May 2019 15:54:47 +0200 Subject: [PATCH] sql/analyzer: refactor and simplify resolve_columns rule Signed-off-by: Miguel Molina --- sql/analyzer/convert_dates.go | 5 - sql/analyzer/prune_columns.go | 11 +- sql/analyzer/resolve_columns.go | 172 +++++++++++++------------------- 3 files changed, 73 insertions(+), 115 deletions(-) diff --git a/sql/analyzer/convert_dates.go b/sql/analyzer/convert_dates.go index 10bddeb34..326c1cbfc 100644 --- a/sql/analyzer/convert_dates.go +++ b/sql/analyzer/convert_dates.go @@ -9,11 +9,6 @@ import ( "gopkg.in/src-d/go-mysql-server.v0/sql/plan" ) -type tableCol struct { - table string - col string -} - // convertDates wraps all expressions of date and datetime type with converts // to ensure the date range is validated. func convertDates(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { diff --git a/sql/analyzer/prune_columns.go b/sql/analyzer/prune_columns.go index 1bbc65932..008c4ccd7 100644 --- a/sql/analyzer/prune_columns.go +++ b/sql/analyzer/prune_columns.go @@ -154,11 +154,6 @@ func pruneUnusedColumns(n sql.Node, columns usedColumns) (sql.Node, error) { }) } -type tableColumnPair struct { - table string - column string -} - func fixRemainingFieldsIndexes(n sql.Node) (sql.Node, error) { return n.TransformUp(func(n sql.Node) (sql.Node, error) { switch n := n.(type) { @@ -184,9 +179,9 @@ func fixRemainingFieldsIndexes(n sql.Node) (sql.Node, error) { return n, nil } - indexes := make(map[tableColumnPair]int) + indexes := make(map[tableCol]int) for i, col := range schema { - indexes[tableColumnPair{col.Source, col.Name}] = i + indexes[tableCol{col.Source, col.Name}] = i } return exp.TransformExpressions(func(e sql.Expression) (sql.Expression, error) { @@ -195,7 +190,7 @@ func fixRemainingFieldsIndexes(n sql.Node) (sql.Node, error) { return e, nil } - idx, ok := indexes[tableColumnPair{gf.Table(), gf.Name()}] + idx, ok := indexes[tableCol{gf.Table(), gf.Name()}] if !ok { return nil, ErrColumnTableNotFound.New(gf.Table(), gf.Name()) } diff --git a/sql/analyzer/resolve_columns.go b/sql/analyzer/resolve_columns.go index ddb84cdef..1fa006e07 100644 --- a/sql/analyzer/resolve_columns.go +++ b/sql/analyzer/resolve_columns.go @@ -102,6 +102,16 @@ func (e deferredColumn) TransformUp(fn sql.TransformExprFunc) (sql.Expression, e return fn(e) } +type tableCol struct { + table string + col string +} + +type indexedCol struct { + *sql.Column + index int +} + // column is the common interface that groups UnresolvedColumn and deferredColumn. type column interface { sql.Nameable @@ -285,136 +295,94 @@ func resolveColumns(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) return n, nil } - var childSchema sql.Schema - colMap := make(map[string][]*sql.Column) - for _, child := range n.Children() { - if !child.Resolved() { - return n, nil - } - - for _, col := range child.Schema() { - name := strings.ToLower(col.Name) - colMap[name] = append(colMap[name], col) - childSchema = append(childSchema, col) - } - } - expressioner, ok := n.(sql.Expressioner) if !ok { return n, nil } - // make sure all children are resolved before resolving a node + // We need to use the schema, so all children must be resolved. for _, c := range n.Children() { if !c.Resolved() { - a.Log("a children with type %T of node %T were not resolved, skipping", c, n) return n, nil } } + columns := findChildIndexedColumns(n) return expressioner.TransformExpressions(func(e sql.Expression) (sql.Expression, error) { a.Log("transforming expression of type: %T", e) - if e.Resolved() { - return e, nil - } uc, ok := e.(column) - if !ok { + if !ok || e.Resolved() { return e, nil } - name := strings.ToLower(uc.Name()) - table := strings.ToLower(uc.Table()) - - // First of all, try to find the field in the child schema, which - // will resolve aliases. - if idx := childSchema.IndexOf(name, table); idx >= 0 { - col := childSchema[idx] - return expression.NewGetFieldWithTable(idx, col.Type, col.Source, col.Name, col.Nullable), nil + if isGlobalOrSessionColumn(uc) { + return resolveGlobalOrSessionColumn(ctx, uc) } - columns, ok := colMap[name] - if !ok { - switch uc := uc.(type) { - case *expression.UnresolvedColumn: - if isGlobalOrSessionColumn(uc) { - if table != "" && table != sessionTable { - return nil, errGlobalVariablesNotSupported.New(uc) - } - - name := strings.TrimLeft(uc.Name(), "@") - name = strings.TrimPrefix(strings.TrimPrefix(name, globalPrefix), sessionPrefix) - typ, value := ctx.Get(name) - return expression.NewGetSessionField(name, typ, value), nil - } - - a.Log("evaluation of column %q was deferred", uc.Name()) - return &deferredColumn{uc}, nil - - default: - if len(colMap) == 0 { - return nil, ErrColumnNotFound.New(uc.Name()) - } - - if table != "" { - return nil, ErrColumnTableNotFound.New(uc.Table(), uc.Name()) - } - - similar := similartext.FindFromMap(colMap, uc.Name()) - return nil, ErrColumnNotFound.New(uc.Name() + similar) - } - } + return resolveColumnExpression(ctx, uc, columns) + }) + }) +} - var col *sql.Column - var found bool - for _, c := range columns { - _, ok := n.(*plan.GroupBy) - if ok || (strings.ToLower(c.Source) == table) { - col = c - found = true - break - } - } +func findChildIndexedColumns(n sql.Node) map[tableCol]indexedCol { + var idx int + var columns = make(map[tableCol]indexedCol) + + for _, child := range n.Children() { + for _, col := range child.Schema() { + columns[tableCol{ + table: strings.ToLower(col.Source), + col: strings.ToLower(col.Name), + }] = indexedCol{col, idx} + idx++ + } + } - if !found { - if table != "" { - return nil, ErrColumnTableNotFound.New(uc.Table(), uc.Name()) - } + return columns +} - switch uc := uc.(type) { - case *expression.UnresolvedColumn: - return &deferredColumn{uc}, nil - default: - return nil, ErrColumnNotFound.New(uc.Name()) - } - } +func resolveGlobalOrSessionColumn(ctx *sql.Context, col column) (sql.Expression, error) { + if col.Table() != "" && strings.ToLower(col.Table()) != sessionTable { + return nil, errGlobalVariablesNotSupported.New(col) + } - var schema sql.Schema - // If expressioner and unary node we must take the - // child's schema to correctly select the indexes - // in the row is going to be evaluated in this node - if plan.IsUnary(n) { - schema = n.Children()[0].Schema() - } else { - schema = n.Schema() - } + name := strings.TrimLeft(col.Name(), "@") + name = strings.TrimPrefix(strings.TrimPrefix(name, globalPrefix), sessionPrefix) + typ, value := ctx.Get(name) + return expression.NewGetSessionField(name, typ, value), nil +} - idx := schema.IndexOf(col.Name, col.Source) - if idx < 0 { - return nil, ErrColumnNotFound.New(col.Name) +func resolveColumnExpression( + ctx *sql.Context, + e column, + columns map[tableCol]indexedCol, +) (sql.Expression, error) { + name := strings.ToLower(e.Name()) + table := strings.ToLower(e.Table()) + col, ok := columns[tableCol{table, name}] + if !ok { + switch uc := e.(type) { + case *expression.UnresolvedColumn: + // Defer the resolution of the column to give the analyzer more + // time to resolve other parts so this can be resolved. + return &deferredColumn{uc}, nil + default: + if table != "" { + return nil, ErrColumnTableNotFound.New(e.Table(), e.Name()) } - a.Log("column resolved to %q.%q", col.Source, col.Name) + return nil, ErrColumnNotFound.New(e.Name()) + } + } - return expression.NewGetFieldWithTable( - idx, - col.Type, - col.Source, - col.Name, - col.Nullable, - ), nil - }) - }) + return expression.NewGetFieldWithTable( + col.index, + col.Type, + col.Source, + col.Name, + col.Nullable, + ), nil } // resolveGroupingColumns reorders the aggregation in a groupby so aliases