Skip to content
This repository has been archived by the owner on Jan 28, 2021. It is now read-only.

Commit

Permalink
sql/analyzer: refactor and simplify resolve_columns rule (#714)
Browse files Browse the repository at this point in the history
sql/analyzer: refactor and simplify resolve_columns rule
  • Loading branch information
ajnavarro committed May 20, 2019
2 parents 8ae4350 + f0832ac commit 3f7b3fe
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 115 deletions.
5 changes: 0 additions & 5 deletions sql/analyzer/convert_dates.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,6 @@ import (
"gopkg.in/src-d/go-mysql-server.v0/sql/plan"
)

type tableCol struct {
table string
col string
}

// convertDates wraps all expressions of date and datetime type with converts
// to ensure the date range is validated.
func convertDates(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) {
Expand Down
11 changes: 3 additions & 8 deletions sql/analyzer/prune_columns.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,11 +154,6 @@ func pruneUnusedColumns(n sql.Node, columns usedColumns) (sql.Node, error) {
})
}

type tableColumnPair struct {
table string
column string
}

func fixRemainingFieldsIndexes(n sql.Node) (sql.Node, error) {
return n.TransformUp(func(n sql.Node) (sql.Node, error) {
switch n := n.(type) {
Expand All @@ -184,9 +179,9 @@ func fixRemainingFieldsIndexes(n sql.Node) (sql.Node, error) {
return n, nil
}

indexes := make(map[tableColumnPair]int)
indexes := make(map[tableCol]int)
for i, col := range schema {
indexes[tableColumnPair{col.Source, col.Name}] = i
indexes[tableCol{col.Source, col.Name}] = i
}

return exp.TransformExpressions(func(e sql.Expression) (sql.Expression, error) {
Expand All @@ -195,7 +190,7 @@ func fixRemainingFieldsIndexes(n sql.Node) (sql.Node, error) {
return e, nil
}

idx, ok := indexes[tableColumnPair{gf.Table(), gf.Name()}]
idx, ok := indexes[tableCol{gf.Table(), gf.Name()}]
if !ok {
return nil, ErrColumnTableNotFound.New(gf.Table(), gf.Name())
}
Expand Down
172 changes: 70 additions & 102 deletions sql/analyzer/resolve_columns.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,16 @@ func (e deferredColumn) TransformUp(fn sql.TransformExprFunc) (sql.Expression, e
return fn(e)
}

type tableCol struct {
table string
col string
}

type indexedCol struct {
*sql.Column
index int
}

// column is the common interface that groups UnresolvedColumn and deferredColumn.
type column interface {
sql.Nameable
Expand Down Expand Up @@ -285,136 +295,94 @@ func resolveColumns(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error)
return n, nil
}

var childSchema sql.Schema
colMap := make(map[string][]*sql.Column)
for _, child := range n.Children() {
if !child.Resolved() {
return n, nil
}

for _, col := range child.Schema() {
name := strings.ToLower(col.Name)
colMap[name] = append(colMap[name], col)
childSchema = append(childSchema, col)
}
}

expressioner, ok := n.(sql.Expressioner)
if !ok {
return n, nil
}

// make sure all children are resolved before resolving a node
// We need to use the schema, so all children must be resolved.
for _, c := range n.Children() {
if !c.Resolved() {
a.Log("a children with type %T of node %T were not resolved, skipping", c, n)
return n, nil
}
}

columns := findChildIndexedColumns(n)
return expressioner.TransformExpressions(func(e sql.Expression) (sql.Expression, error) {
a.Log("transforming expression of type: %T", e)
if e.Resolved() {
return e, nil
}

uc, ok := e.(column)
if !ok {
if !ok || e.Resolved() {
return e, nil
}

name := strings.ToLower(uc.Name())
table := strings.ToLower(uc.Table())

// First of all, try to find the field in the child schema, which
// will resolve aliases.
if idx := childSchema.IndexOf(name, table); idx >= 0 {
col := childSchema[idx]
return expression.NewGetFieldWithTable(idx, col.Type, col.Source, col.Name, col.Nullable), nil
if isGlobalOrSessionColumn(uc) {
return resolveGlobalOrSessionColumn(ctx, uc)
}

columns, ok := colMap[name]
if !ok {
switch uc := uc.(type) {
case *expression.UnresolvedColumn:
if isGlobalOrSessionColumn(uc) {
if table != "" && table != sessionTable {
return nil, errGlobalVariablesNotSupported.New(uc)
}

name := strings.TrimLeft(uc.Name(), "@")
name = strings.TrimPrefix(strings.TrimPrefix(name, globalPrefix), sessionPrefix)
typ, value := ctx.Get(name)
return expression.NewGetSessionField(name, typ, value), nil
}

a.Log("evaluation of column %q was deferred", uc.Name())
return &deferredColumn{uc}, nil

default:
if len(colMap) == 0 {
return nil, ErrColumnNotFound.New(uc.Name())
}

if table != "" {
return nil, ErrColumnTableNotFound.New(uc.Table(), uc.Name())
}

similar := similartext.FindFromMap(colMap, uc.Name())
return nil, ErrColumnNotFound.New(uc.Name() + similar)
}
}
return resolveColumnExpression(ctx, uc, columns)
})
})
}

var col *sql.Column
var found bool
for _, c := range columns {
_, ok := n.(*plan.GroupBy)
if ok || (strings.ToLower(c.Source) == table) {
col = c
found = true
break
}
}
func findChildIndexedColumns(n sql.Node) map[tableCol]indexedCol {
var idx int
var columns = make(map[tableCol]indexedCol)

for _, child := range n.Children() {
for _, col := range child.Schema() {
columns[tableCol{
table: strings.ToLower(col.Source),
col: strings.ToLower(col.Name),
}] = indexedCol{col, idx}
idx++
}
}

if !found {
if table != "" {
return nil, ErrColumnTableNotFound.New(uc.Table(), uc.Name())
}
return columns
}

switch uc := uc.(type) {
case *expression.UnresolvedColumn:
return &deferredColumn{uc}, nil
default:
return nil, ErrColumnNotFound.New(uc.Name())
}
}
func resolveGlobalOrSessionColumn(ctx *sql.Context, col column) (sql.Expression, error) {
if col.Table() != "" && strings.ToLower(col.Table()) != sessionTable {
return nil, errGlobalVariablesNotSupported.New(col)
}

var schema sql.Schema
// If expressioner and unary node we must take the
// child's schema to correctly select the indexes
// in the row is going to be evaluated in this node
if plan.IsUnary(n) {
schema = n.Children()[0].Schema()
} else {
schema = n.Schema()
}
name := strings.TrimLeft(col.Name(), "@")
name = strings.TrimPrefix(strings.TrimPrefix(name, globalPrefix), sessionPrefix)
typ, value := ctx.Get(name)
return expression.NewGetSessionField(name, typ, value), nil
}

idx := schema.IndexOf(col.Name, col.Source)
if idx < 0 {
return nil, ErrColumnNotFound.New(col.Name)
func resolveColumnExpression(
ctx *sql.Context,
e column,
columns map[tableCol]indexedCol,
) (sql.Expression, error) {
name := strings.ToLower(e.Name())
table := strings.ToLower(e.Table())
col, ok := columns[tableCol{table, name}]
if !ok {
switch uc := e.(type) {
case *expression.UnresolvedColumn:
// Defer the resolution of the column to give the analyzer more
// time to resolve other parts so this can be resolved.
return &deferredColumn{uc}, nil
default:
if table != "" {
return nil, ErrColumnTableNotFound.New(e.Table(), e.Name())
}

a.Log("column resolved to %q.%q", col.Source, col.Name)
return nil, ErrColumnNotFound.New(e.Name())
}
}

return expression.NewGetFieldWithTable(
idx,
col.Type,
col.Source,
col.Name,
col.Nullable,
), nil
})
})
return expression.NewGetFieldWithTable(
col.index,
col.Type,
col.Source,
col.Name,
col.Nullable,
), nil
}

// resolveGroupingColumns reorders the aggregation in a groupby so aliases
Expand Down

0 comments on commit 3f7b3fe

Please sign in to comment.