Skip to content
This repository has been archived by the owner on Jan 28, 2021. It is now read-only.

allow all expressions in grouping, resolve orderby expressions #633

Merged
merged 1 commit into from
Mar 12, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -868,6 +868,14 @@ var queries = []struct {
"ROLLBACK",
[]sql.Row{},
},
{
"SELECT substring(s, 1, 1) FROM mytable ORDER BY substring(s, 1, 1)",
[]sql.Row{{"f"}, {"s"}, {"t"}},
},
{
"SELECT substring(s, 1, 1), count(*) FROM mytable GROUP BY substring(s, 1, 1)",
[]sql.Row{{"f", int32(1)}, {"s", int32(1)}, {"t", int32(1)}},
},
}

func TestQueries(t *testing.T) {
Expand Down
28 changes: 20 additions & 8 deletions sql/analyzer/resolve_orderby.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,14 @@ func resolveOrderBy(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error)
var colsFromChild []string
var missingCols []string
for _, f := range sort.SortFields {
n, ok := f.Column.(sql.Nameable)
if !ok {
continue
}
ns := findExprNameables(f.Column)

if stringContains(childNewCols, n.Name()) {
colsFromChild = append(colsFromChild, n.Name())
} else if !stringContains(schemaCols, n.Name()) {
missingCols = append(missingCols, n.Name())
for _, n := range ns {
if stringContains(childNewCols, n.Name()) {
colsFromChild = append(colsFromChild, n.Name())
} else if !stringContains(schemaCols, n.Name()) {
missingCols = append(missingCols, n.Name())
}
}
}

Expand Down Expand Up @@ -221,3 +220,16 @@ func resolveOrderByLiterals(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node
return plan.NewSort(fields, sort.Child), nil
})
}

func findExprNameables(e sql.Expression) []sql.Nameable {
var result []sql.Nameable
expression.Inspect(e, func(e sql.Expression) bool {
n, ok := e.(sql.Nameable)
if ok {
result = append(result, n)
return false
}
return true
})
return result
}
8 changes: 2 additions & 6 deletions sql/plan/group_by.go
Original file line number Diff line number Diff line change
Expand Up @@ -353,15 +353,13 @@ func updateBuffer(
return n.Update(ctx, buffers[idx], row)
case *expression.Alias:
return updateBuffer(ctx, buffers, idx, n.Child, row)
case *expression.GetField:
default:
val, err := expr.Eval(ctx, row)
if err != nil {
return err
}
buffers[idx] = sql.NewRow(val)
return nil
default:
return ErrGroupBy.New(n.String())
}
}

Expand Down Expand Up @@ -393,12 +391,10 @@ func evalBuffer(
return n.Eval(ctx, buffer)
case *expression.Alias:
return evalBuffer(ctx, n.Child, buffer)
case *expression.GetField:
default:
if len(buffer) > 0 {
return buffer[0], nil
}
return nil, nil
default:
return nil, ErrGroupBy.New(n.String())
}
}
21 changes: 14 additions & 7 deletions sql/plan/group_by_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
"gopkg.in/src-d/go-mysql-server.v0/sql/expression/function/aggregation"
)

func TestGroupBy_Schema(t *testing.T) {
func TestGroupBySchema(t *testing.T) {
require := require.New(t)

child := mem.NewTable("test", nil)
Expand All @@ -25,7 +25,7 @@ func TestGroupBy_Schema(t *testing.T) {
}, gb.Schema())
}

func TestGroupBy_Resolved(t *testing.T) {
func TestGroupByResolved(t *testing.T) {
require := require.New(t)

child := mem.NewTable("test", nil)
Expand All @@ -42,7 +42,7 @@ func TestGroupBy_Resolved(t *testing.T) {
require.False(gb.Resolved())
}

func TestGroupBy_RowIter(t *testing.T) {
func TestGroupByRowIter(t *testing.T) {
require := require.New(t)
ctx := sql.NewEmptyContext()

Expand Down Expand Up @@ -96,7 +96,7 @@ func TestGroupBy_RowIter(t *testing.T) {
require.Equal(sql.NewRow("col1_2", int64(4444)), rows[1])
}

func TestGroupBy_EvalEmptyBuffer(t *testing.T) {
func TestGroupByEvalEmptyBuffer(t *testing.T) {
require := require.New(t)
ctx := sql.NewEmptyContext()

Expand All @@ -105,7 +105,7 @@ func TestGroupBy_EvalEmptyBuffer(t *testing.T) {
require.Nil(r)
}

func TestGroupBy_Error(t *testing.T) {
func TestGroupByAggregationGrouping(t *testing.T) {
require := require.New(t)
ctx := sql.NewEmptyContext()

Expand Down Expand Up @@ -140,8 +140,15 @@ func TestGroupBy_Error(t *testing.T) {
NewResolvedTable(child),
)

_, err := sql.NodeToRows(ctx, p)
require.Error(err)
rows, err := sql.NodeToRows(ctx, p)
require.NoError(err)

expected := []sql.Row{
{int32(3), false},
{int32(2), false},
}

require.Equal(expected, rows)
}

func BenchmarkGroupBy(b *testing.B) {
Expand Down