Skip to content
This repository has been archived by the owner on Jan 28, 2021. It is now read-only.

sql/analyzer: refactor and simplify resolve_columns rule #714

Merged
merged 1 commit into from
May 20, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions sql/analyzer/convert_dates.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,6 @@ import (
"gopkg.in/src-d/go-mysql-server.v0/sql/plan"
)

type tableCol struct {
table string
col string
}

// convertDates wraps all expressions of date and datetime type with converts
// to ensure the date range is validated.
func convertDates(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) {
Expand Down
11 changes: 3 additions & 8 deletions sql/analyzer/prune_columns.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,11 +154,6 @@ func pruneUnusedColumns(n sql.Node, columns usedColumns) (sql.Node, error) {
})
}

type tableColumnPair struct {
table string
column string
}

func fixRemainingFieldsIndexes(n sql.Node) (sql.Node, error) {
return n.TransformUp(func(n sql.Node) (sql.Node, error) {
switch n := n.(type) {
Expand All @@ -184,9 +179,9 @@ func fixRemainingFieldsIndexes(n sql.Node) (sql.Node, error) {
return n, nil
}

indexes := make(map[tableColumnPair]int)
indexes := make(map[tableCol]int)
for i, col := range schema {
indexes[tableColumnPair{col.Source, col.Name}] = i
indexes[tableCol{col.Source, col.Name}] = i
}

return exp.TransformExpressions(func(e sql.Expression) (sql.Expression, error) {
Expand All @@ -195,7 +190,7 @@ func fixRemainingFieldsIndexes(n sql.Node) (sql.Node, error) {
return e, nil
}

idx, ok := indexes[tableColumnPair{gf.Table(), gf.Name()}]
idx, ok := indexes[tableCol{gf.Table(), gf.Name()}]
if !ok {
return nil, ErrColumnTableNotFound.New(gf.Table(), gf.Name())
}
Expand Down
172 changes: 70 additions & 102 deletions sql/analyzer/resolve_columns.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,16 @@ func (e deferredColumn) TransformUp(fn sql.TransformExprFunc) (sql.Expression, e
return fn(e)
}

type tableCol struct {
table string
col string
}

type indexedCol struct {
*sql.Column
index int
}

// column is the common interface that groups UnresolvedColumn and deferredColumn.
type column interface {
sql.Nameable
Expand Down Expand Up @@ -285,136 +295,94 @@ func resolveColumns(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error)
return n, nil
}

var childSchema sql.Schema
colMap := make(map[string][]*sql.Column)
for _, child := range n.Children() {
if !child.Resolved() {
return n, nil
}

for _, col := range child.Schema() {
name := strings.ToLower(col.Name)
colMap[name] = append(colMap[name], col)
childSchema = append(childSchema, col)
}
}

expressioner, ok := n.(sql.Expressioner)
if !ok {
return n, nil
}

// make sure all children are resolved before resolving a node
// We need to use the schema, so all children must be resolved.
for _, c := range n.Children() {
if !c.Resolved() {
a.Log("a children with type %T of node %T were not resolved, skipping", c, n)
return n, nil
}
}

columns := findChildIndexedColumns(n)
return expressioner.TransformExpressions(func(e sql.Expression) (sql.Expression, error) {
a.Log("transforming expression of type: %T", e)
if e.Resolved() {
return e, nil
}

uc, ok := e.(column)
if !ok {
if !ok || e.Resolved() {
return e, nil
}

name := strings.ToLower(uc.Name())
table := strings.ToLower(uc.Table())

// First of all, try to find the field in the child schema, which
// will resolve aliases.
if idx := childSchema.IndexOf(name, table); idx >= 0 {
col := childSchema[idx]
return expression.NewGetFieldWithTable(idx, col.Type, col.Source, col.Name, col.Nullable), nil
if isGlobalOrSessionColumn(uc) {
return resolveGlobalOrSessionColumn(ctx, uc)
}

columns, ok := colMap[name]
if !ok {
switch uc := uc.(type) {
case *expression.UnresolvedColumn:
if isGlobalOrSessionColumn(uc) {
if table != "" && table != sessionTable {
return nil, errGlobalVariablesNotSupported.New(uc)
}

name := strings.TrimLeft(uc.Name(), "@")
name = strings.TrimPrefix(strings.TrimPrefix(name, globalPrefix), sessionPrefix)
typ, value := ctx.Get(name)
return expression.NewGetSessionField(name, typ, value), nil
}

a.Log("evaluation of column %q was deferred", uc.Name())
return &deferredColumn{uc}, nil

default:
if len(colMap) == 0 {
return nil, ErrColumnNotFound.New(uc.Name())
}

if table != "" {
return nil, ErrColumnTableNotFound.New(uc.Table(), uc.Name())
}

similar := similartext.FindFromMap(colMap, uc.Name())
return nil, ErrColumnNotFound.New(uc.Name() + similar)
}
}
return resolveColumnExpression(ctx, uc, columns)
})
})
}

var col *sql.Column
var found bool
for _, c := range columns {
_, ok := n.(*plan.GroupBy)
if ok || (strings.ToLower(c.Source) == table) {
col = c
found = true
break
}
}
func findChildIndexedColumns(n sql.Node) map[tableCol]indexedCol {
var idx int
var columns = make(map[tableCol]indexedCol)

for _, child := range n.Children() {
for _, col := range child.Schema() {
columns[tableCol{
table: strings.ToLower(col.Source),
col: strings.ToLower(col.Name),
}] = indexedCol{col, idx}
idx++
}
}

if !found {
if table != "" {
return nil, ErrColumnTableNotFound.New(uc.Table(), uc.Name())
}
return columns
}

switch uc := uc.(type) {
case *expression.UnresolvedColumn:
return &deferredColumn{uc}, nil
default:
return nil, ErrColumnNotFound.New(uc.Name())
}
}
func resolveGlobalOrSessionColumn(ctx *sql.Context, col column) (sql.Expression, error) {
if col.Table() != "" && strings.ToLower(col.Table()) != sessionTable {
return nil, errGlobalVariablesNotSupported.New(col)
}

var schema sql.Schema
// If expressioner and unary node we must take the
// child's schema to correctly select the indexes
// in the row is going to be evaluated in this node
if plan.IsUnary(n) {
schema = n.Children()[0].Schema()
} else {
schema = n.Schema()
}
name := strings.TrimLeft(col.Name(), "@")
name = strings.TrimPrefix(strings.TrimPrefix(name, globalPrefix), sessionPrefix)
typ, value := ctx.Get(name)
return expression.NewGetSessionField(name, typ, value), nil
}

idx := schema.IndexOf(col.Name, col.Source)
if idx < 0 {
return nil, ErrColumnNotFound.New(col.Name)
func resolveColumnExpression(
ctx *sql.Context,
e column,
columns map[tableCol]indexedCol,
) (sql.Expression, error) {
name := strings.ToLower(e.Name())
table := strings.ToLower(e.Table())
col, ok := columns[tableCol{table, name}]
if !ok {
switch uc := e.(type) {
case *expression.UnresolvedColumn:
// Defer the resolution of the column to give the analyzer more
// time to resolve other parts so this can be resolved.
return &deferredColumn{uc}, nil
default:
if table != "" {
return nil, ErrColumnTableNotFound.New(e.Table(), e.Name())
}

a.Log("column resolved to %q.%q", col.Source, col.Name)
return nil, ErrColumnNotFound.New(e.Name())
}
}

return expression.NewGetFieldWithTable(
idx,
col.Type,
col.Source,
col.Name,
col.Nullable,
), nil
})
})
return expression.NewGetFieldWithTable(
col.index,
col.Type,
col.Source,
col.Name,
col.Nullable,
), nil
}

// resolveGroupingColumns reorders the aggregation in a groupby so aliases
Expand Down