Skip to content
This repository has been archived by the owner on Jan 28, 2021. It is now read-only.

Commit

Permalink
Merge pull request #633 from erizocosmico/fix/reorder-error
Browse files Browse the repository at this point in the history
allow all expressions in grouping, resolve orderby expressions
  • Loading branch information
ajnavarro committed Mar 12, 2019
2 parents 08e98ce + 52476d6 commit b829206
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 21 deletions.
8 changes: 8 additions & 0 deletions engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -868,6 +868,14 @@ var queries = []struct {
"ROLLBACK",
[]sql.Row{},
},
{
"SELECT substring(s, 1, 1) FROM mytable ORDER BY substring(s, 1, 1)",
[]sql.Row{{"f"}, {"s"}, {"t"}},
},
{
"SELECT substring(s, 1, 1), count(*) FROM mytable GROUP BY substring(s, 1, 1)",
[]sql.Row{{"f", int32(1)}, {"s", int32(1)}, {"t", int32(1)}},
},
}

func TestQueries(t *testing.T) {
Expand Down
28 changes: 20 additions & 8 deletions sql/analyzer/resolve_orderby.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,14 @@ func resolveOrderBy(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error)
var colsFromChild []string
var missingCols []string
for _, f := range sort.SortFields {
n, ok := f.Column.(sql.Nameable)
if !ok {
continue
}
ns := findExprNameables(f.Column)

if stringContains(childNewCols, n.Name()) {
colsFromChild = append(colsFromChild, n.Name())
} else if !stringContains(schemaCols, n.Name()) {
missingCols = append(missingCols, n.Name())
for _, n := range ns {
if stringContains(childNewCols, n.Name()) {
colsFromChild = append(colsFromChild, n.Name())
} else if !stringContains(schemaCols, n.Name()) {
missingCols = append(missingCols, n.Name())
}
}
}

Expand Down Expand Up @@ -221,3 +220,16 @@ func resolveOrderByLiterals(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node
return plan.NewSort(fields, sort.Child), nil
})
}

func findExprNameables(e sql.Expression) []sql.Nameable {
var result []sql.Nameable
expression.Inspect(e, func(e sql.Expression) bool {
n, ok := e.(sql.Nameable)
if ok {
result = append(result, n)
return false
}
return true
})
return result
}
8 changes: 2 additions & 6 deletions sql/plan/group_by.go
Original file line number Diff line number Diff line change
Expand Up @@ -353,15 +353,13 @@ func updateBuffer(
return n.Update(ctx, buffers[idx], row)
case *expression.Alias:
return updateBuffer(ctx, buffers, idx, n.Child, row)
case *expression.GetField:
default:
val, err := expr.Eval(ctx, row)
if err != nil {
return err
}
buffers[idx] = sql.NewRow(val)
return nil
default:
return ErrGroupBy.New(n.String())
}
}

Expand Down Expand Up @@ -393,12 +391,10 @@ func evalBuffer(
return n.Eval(ctx, buffer)
case *expression.Alias:
return evalBuffer(ctx, n.Child, buffer)
case *expression.GetField:
default:
if len(buffer) > 0 {
return buffer[0], nil
}
return nil, nil
default:
return nil, ErrGroupBy.New(n.String())
}
}
21 changes: 14 additions & 7 deletions sql/plan/group_by_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
"gopkg.in/src-d/go-mysql-server.v0/sql/expression/function/aggregation"
)

func TestGroupBy_Schema(t *testing.T) {
func TestGroupBySchema(t *testing.T) {
require := require.New(t)

child := mem.NewTable("test", nil)
Expand All @@ -25,7 +25,7 @@ func TestGroupBy_Schema(t *testing.T) {
}, gb.Schema())
}

func TestGroupBy_Resolved(t *testing.T) {
func TestGroupByResolved(t *testing.T) {
require := require.New(t)

child := mem.NewTable("test", nil)
Expand All @@ -42,7 +42,7 @@ func TestGroupBy_Resolved(t *testing.T) {
require.False(gb.Resolved())
}

func TestGroupBy_RowIter(t *testing.T) {
func TestGroupByRowIter(t *testing.T) {
require := require.New(t)
ctx := sql.NewEmptyContext()

Expand Down Expand Up @@ -96,7 +96,7 @@ func TestGroupBy_RowIter(t *testing.T) {
require.Equal(sql.NewRow("col1_2", int64(4444)), rows[1])
}

func TestGroupBy_EvalEmptyBuffer(t *testing.T) {
func TestGroupByEvalEmptyBuffer(t *testing.T) {
require := require.New(t)
ctx := sql.NewEmptyContext()

Expand All @@ -105,7 +105,7 @@ func TestGroupBy_EvalEmptyBuffer(t *testing.T) {
require.Nil(r)
}

func TestGroupBy_Error(t *testing.T) {
func TestGroupByAggregationGrouping(t *testing.T) {
require := require.New(t)
ctx := sql.NewEmptyContext()

Expand Down Expand Up @@ -140,8 +140,15 @@ func TestGroupBy_Error(t *testing.T) {
NewResolvedTable(child),
)

_, err := sql.NodeToRows(ctx, p)
require.Error(err)
rows, err := sql.NodeToRows(ctx, p)
require.NoError(err)

expected := []sql.Row{
{int32(3), false},
{int32(2), false},
}

require.Equal(expected, rows)
}

func BenchmarkGroupBy(b *testing.B) {
Expand Down

0 comments on commit b829206

Please sign in to comment.