Skip to content
This repository has been archived by the owner on Jan 28, 2021. It is now read-only.

sql/analyzer: refactor and fix bugs in qualify_columns rule #706

Merged
merged 1 commit into from
May 13, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1060,6 +1060,10 @@ var queries = []struct {
`SELECT t.date_col FROM (SELECT CONVERT('2019-06-06 00:00:00', DATETIME) as date_col) t GROUP BY t.date_col`,
[]sql.Row{{time.Date(2019, time.June, 6, 0, 0, 0, 0, time.UTC)}},
},
{
`SELECT i AS foo FROM mytable ORDER BY mytable.i`,
[]sql.Row{{int64(1)}, {int64(2)}, {int64(3)}},
},
}

func TestQueries(t *testing.T) {
Expand Down
310 changes: 119 additions & 191 deletions sql/analyzer/resolve_columns.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,228 +110,156 @@ type column interface {
}

func qualifyColumns(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) {
span, _ := ctx.Span("qualify_columns")
defer span.Finish()

a.Log("qualify columns")
tables := make(map[string]sql.Node)
tableAliases := make(map[string]string)
colIndex := make(map[string][]string)

indexCols := func(table string, schema sql.Schema) {
for _, col := range schema {
name := strings.ToLower(col.Name)
colIndex[name] = append(colIndex[name], strings.ToLower(table))
}
}

var projects, seenProjects int
plan.Inspect(n, func(n sql.Node) bool {
if _, ok := n.(*plan.Project); ok {
projects++
}
return true
})

return n.TransformUp(func(n sql.Node) (sql.Node, error) {
a.Log("transforming node of type: %T", n)
switch n := n.(type) {
case *plan.TableAlias:
switch t := n.Child.(type) {
case *plan.ResolvedTable, *plan.UnresolvedTable:
name := strings.ToLower(t.(sql.Nameable).Name())
tableAliases[strings.ToLower(n.Name())] = name
default:
tables[strings.ToLower(n.Name())] = n.Child
indexCols(n.Name(), n.Schema())
}
case *plan.ResolvedTable, *plan.SubqueryAlias:
name := strings.ToLower(n.(sql.Nameable).Name())
tables[name] = n
indexCols(name, n.Schema())
}

exp, ok := n.(sql.Expressioner)
if !ok {
if !ok || n.Resolved() {
return n, nil
}

result, err := exp.TransformExpressions(func(e sql.Expression) (sql.Expression, error) {
a.Log("transforming expression of type: %T", e)
switch col := e.(type) {
case *expression.UnresolvedColumn:
// Skip this step for global and session variables
if isGlobalOrSessionColumn(col) {
return col, nil
}
columns := getNodeAvailableColumns(n)
tables := getNodeAvailableTables(n)

col = expression.NewUnresolvedQualifiedColumn(col.Table(), col.Name())
name := strings.ToLower(col.Name())
table := strings.ToLower(col.Table())
if table == "" {
// If a column has no table, it might be an alias
// defined in a child projection, so check that instead
// of incorrectly qualify it.
if isDefinedInChildProject(n, col) {
return col, nil
}

tables := dedupStrings(colIndex[name])
switch len(tables) {
case 0:
// If there are no tables that have any column with the column
// name let's just return it as it is. This may be an alias, so
// we'll wait for the reorder of the projection.
return col, nil
case 1:
col = expression.NewUnresolvedQualifiedColumn(
tables[0],
col.Name(),
)
default:
if _, ok := n.(*plan.GroupBy); ok {
return expression.NewUnresolvedColumn(col.Name()), nil
}
return nil, ErrAmbiguousColumnName.New(col.Name(), strings.Join(tables, ", "))
}
} else {
if real, ok := tableAliases[table]; ok {
col = expression.NewUnresolvedQualifiedColumn(
real,
col.Name(),
)
}
return exp.TransformExpressions(func(e sql.Expression) (sql.Expression, error) {
return qualifyExpression(e, columns, tables)
})
})
}

if _, ok := tables[col.Table()]; !ok {
if len(tables) == 0 {
return nil, sql.ErrTableNotFound.New(col.Table())
}
func qualifyExpression(
e sql.Expression,
columns map[string][]string,
tables map[string]string,
) (sql.Expression, error) {
switch col := e.(type) {
case column:
// Skip this step for global and session variables
if isGlobalOrSessionColumn(col) {
return col, nil
}

similar := similartext.FindFromMap(tables, col.Table())
return nil, sql.ErrTableNotFound.New(col.Table() + similar)
}
name, table := strings.ToLower(col.Name()), strings.ToLower(col.Table())
availableTables := dedupStrings(columns[name])
if table != "" {
table, ok := tables[table]
if !ok {
if len(tables) == 0 {
return nil, sql.ErrTableNotFound.New(col.Table())
}

a.Log("column %q was qualified with table %q", col.Name(), col.Table())
similar := similartext.FindFromMap(tables, col.Table())
return nil, sql.ErrTableNotFound.New(col.Table() + similar)
}

// If the table exists but it's not available for this node it
// means some work is still needed, so just return the column
// and let it be resolved in the next pass.
if !stringContains(availableTables, table) {
return col, nil
case *expression.Star:
if col.Table != "" {
if real, ok := tableAliases[strings.ToLower(col.Table)]; ok {
col = expression.NewQualifiedStar(real)
}
}

if _, ok := tables[strings.ToLower(col.Table)]; !ok {
return nil, sql.ErrTableNotFound.New(col.Table)
}
return expression.NewUnresolvedQualifiedColumn(table, col.Name()), nil
}

return col, nil
}
default:
// If any other kind of expression has a star, just replace it
// with an unqualified star because it cannot be expanded.
return e.TransformUp(func(e sql.Expression) (sql.Expression, error) {
if _, ok := e.(*expression.Star); ok {
return expression.NewStar(), nil
}
return e, nil
})
switch len(availableTables) {
case 0:
// If there are no tables that have any column with the column
// name let's just return it as it is. This may be an alias, so
// we'll wait for the reorder of the projection.
return col, nil
case 1:
return expression.NewUnresolvedQualifiedColumn(
availableTables[0],
col.Name(),
), nil
default:
return nil, ErrAmbiguousColumnName.New(col.Name(), strings.Join(availableTables, ", "))
}
case *expression.Star:
if col.Table != "" {
if real, ok := tables[strings.ToLower(col.Table)]; ok {
col = expression.NewQualifiedStar(real)
}

if _, ok := tables[strings.ToLower(col.Table)]; !ok {
return nil, sql.ErrTableNotFound.New(col.Table)
}
}
return col, nil
default:
// If any other kind of expression has a star, just replace it
// with an unqualified star because it cannot be expanded.
return e.TransformUp(func(e sql.Expression) (sql.Expression, error) {
if _, ok := e.(*expression.Star); ok {
return expression.NewStar(), nil
}
return e, nil
})
}
}

if err != nil {
return nil, err
}
func getNodeAvailableColumns(n sql.Node) map[string][]string {
var columns = make(map[string][]string)
getColumnsInNodes(n.Children(), columns)
return columns
}

// We should ignore the topmost project, because some nodes are
// reordered, such as Sort, and they would not be resolved well.
if n, ok := result.(*plan.Project); ok && projects-seenProjects > 1 {
seenProjects++

// We need to modify the indexed columns to only contain what is
// projected in this project. If the column is not qualified by any
// table, just keep the ones that are currently in the index.
// If it is, then just make those tables available for the column.
// If we don't do this, columns that are not projected will be
// available in this step and may cause false errors or unintended
// results.
var projected = make(map[string][]string)
for _, p := range n.Projections {
var table, col string
switch p := p.(type) {
case column:
table = p.Table()
col = p.Name()
default:
continue
}
func getColumnsInNodes(nodes []sql.Node, columns map[string][]string) {
indexCol := func(table, col string) {
col = strings.ToLower(col)
columns[col] = append(columns[col], strings.ToLower(table))
}

col = strings.ToLower(col)
table = strings.ToLower(table)
if table != "" {
projected[col] = append(projected[col], table)
} else {
projected[col] = append(projected[col], colIndex[col]...)
}
indexExpressions := func(exprs []sql.Expression) {
for _, e := range exprs {
switch e := e.(type) {
case *expression.Alias:
indexCol("", e.Name())
case *expression.GetField:
indexCol(e.Table(), e.Name())
case *expression.UnresolvedColumn:
indexCol(e.Table(), e.Name())
}
}
}

colIndex = make(map[string][]string)
for col, tables := range projected {
colIndex[col] = dedupStrings(tables)
for _, node := range nodes {
switch n := node.(type) {
case *plan.ResolvedTable, *plan.SubqueryAlias:
for _, col := range n.Schema() {
indexCol(col.Source, col.Name)
}
case *plan.Project:
indexExpressions(n.Projections)
case *plan.GroupBy:
indexExpressions(n.Aggregate)
default:
getColumnsInNodes(n.Children(), columns)
}

return result, nil
})
}
}

func isDefinedInChildProject(n sql.Node, col *expression.UnresolvedColumn) bool {
var x sql.Node
for _, child := range n.Children() {
plan.Inspect(child, func(n sql.Node) bool {
func getNodeAvailableTables(n sql.Node) map[string]string {
var tables = make(map[string]string)
for _, c := range n.Children() {
plan.Inspect(c, func(n sql.Node) bool {
switch n := n.(type) {
case *plan.SubqueryAlias:
case *plan.SubqueryAlias, *plan.ResolvedTable:
name := strings.ToLower(n.(sql.Nameable).Name())
tables[name] = name
return false
case *plan.Project, *plan.GroupBy:
if x == nil {
x = n
case *plan.TableAlias:
switch t := n.Child.(type) {
case *plan.ResolvedTable, *plan.UnresolvedTable:
name := strings.ToLower(t.(sql.Nameable).Name())
alias := strings.ToLower(n.Name())
tables[alias] = name
}
return false
default:
return true
}
})

if x != nil {
break
}
}

if x == nil {
return false
}

var found bool
for _, expr := range x.(sql.Expressioner).Expressions() {
switch expr := expr.(type) {
case *expression.Alias:
if strings.ToLower(expr.Name()) == strings.ToLower(col.Name()) {
found = true
}
case column:
if strings.ToLower(expr.Name()) == strings.ToLower(col.Name()) &&
strings.ToLower(expr.Table()) == strings.ToLower(col.Table()) {
found = true
}
}

if found {
break
}
return true
})
}

return found
return tables
}

var errGlobalVariablesNotSupported = errors.NewKind("can't resolve global variable, %s was requested")
Expand Down Expand Up @@ -659,6 +587,6 @@ func dedupStrings(in []string) []string {
return result
}

func isGlobalOrSessionColumn(col *expression.UnresolvedColumn) bool {
func isGlobalOrSessionColumn(col column) bool {
return strings.HasPrefix(col.Name(), "@@") || strings.HasPrefix(col.Table(), "@@")
}
8 changes: 4 additions & 4 deletions sql/analyzer/resolve_columns_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,10 @@ func TestQualifyColumns(t *testing.T) {
require := require.New(t)
f := getRule("qualify_columns")

table := mem.NewTable("mytable", sql.Schema{{Name: "i", Type: sql.Int32}})
table2 := mem.NewTable("mytable2", sql.Schema{{Name: "i", Type: sql.Int32}})
sessionTable := mem.NewTable("@@session", sql.Schema{{Name: "autocommit", Type: sql.Int64}})
globalTable := mem.NewTable("@@global", sql.Schema{{Name: "max_allowed_packet", Type: sql.Int64}})
table := mem.NewTable("mytable", sql.Schema{{Name: "i", Type: sql.Int32, Source: "mytable"}})
table2 := mem.NewTable("mytable2", sql.Schema{{Name: "i", Type: sql.Int32, Source: "mytable2"}})
sessionTable := mem.NewTable("@@session", sql.Schema{{Name: "autocommit", Type: sql.Int64, Source: "@@session"}})
globalTable := mem.NewTable("@@global", sql.Schema{{Name: "max_allowed_packet", Type: sql.Int64, Source: "@@global"}})

node := plan.NewProject(
[]sql.Expression{
Expand Down
Loading