From 52476d654ca76a9bc2eae83d7814af33727b8037 Mon Sep 17 00:00:00 2001 From: Miguel Molina Date: Mon, 11 Mar 2019 14:22:03 +0100 Subject: [PATCH] allow all expressions in grouping, resolve orderby expressions Signed-off-by: Miguel Molina --- engine_test.go | 8 ++++++++ sql/analyzer/resolve_orderby.go | 28 ++++++++++++++++++++-------- sql/plan/group_by.go | 8 ++------ sql/plan/group_by_test.go | 21 ++++++++++++++------- 4 files changed, 44 insertions(+), 21 deletions(-) diff --git a/engine_test.go b/engine_test.go index 3f4da975f..0bcae601a 100644 --- a/engine_test.go +++ b/engine_test.go @@ -868,6 +868,14 @@ var queries = []struct { "ROLLBACK", []sql.Row{}, }, + { + "SELECT substring(s, 1, 1) FROM mytable ORDER BY substring(s, 1, 1)", + []sql.Row{{"f"}, {"s"}, {"t"}}, + }, + { + "SELECT substring(s, 1, 1), count(*) FROM mytable GROUP BY substring(s, 1, 1)", + []sql.Row{{"f", int32(1)}, {"s", int32(1)}, {"t", int32(1)}}, + }, } func TestQueries(t *testing.T) { diff --git a/sql/analyzer/resolve_orderby.go b/sql/analyzer/resolve_orderby.go index a678c3594..111a1a2bd 100644 --- a/sql/analyzer/resolve_orderby.go +++ b/sql/analyzer/resolve_orderby.go @@ -35,15 +35,14 @@ func resolveOrderBy(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) var colsFromChild []string var missingCols []string for _, f := range sort.SortFields { - n, ok := f.Column.(sql.Nameable) - if !ok { - continue - } + ns := findExprNameables(f.Column) - if stringContains(childNewCols, n.Name()) { - colsFromChild = append(colsFromChild, n.Name()) - } else if !stringContains(schemaCols, n.Name()) { - missingCols = append(missingCols, n.Name()) + for _, n := range ns { + if stringContains(childNewCols, n.Name()) { + colsFromChild = append(colsFromChild, n.Name()) + } else if !stringContains(schemaCols, n.Name()) { + missingCols = append(missingCols, n.Name()) + } } } @@ -221,3 +220,16 @@ func resolveOrderByLiterals(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node return plan.NewSort(fields, sort.Child), nil }) } + +func findExprNameables(e sql.Expression) []sql.Nameable { + var result []sql.Nameable + expression.Inspect(e, func(e sql.Expression) bool { + n, ok := e.(sql.Nameable) + if ok { + result = append(result, n) + return false + } + return true + }) + return result +} diff --git a/sql/plan/group_by.go b/sql/plan/group_by.go index cd0b294d7..95036786b 100644 --- a/sql/plan/group_by.go +++ b/sql/plan/group_by.go @@ -353,15 +353,13 @@ func updateBuffer( return n.Update(ctx, buffers[idx], row) case *expression.Alias: return updateBuffer(ctx, buffers, idx, n.Child, row) - case *expression.GetField: + default: val, err := expr.Eval(ctx, row) if err != nil { return err } buffers[idx] = sql.NewRow(val) return nil - default: - return ErrGroupBy.New(n.String()) } } @@ -393,12 +391,10 @@ func evalBuffer( return n.Eval(ctx, buffer) case *expression.Alias: return evalBuffer(ctx, n.Child, buffer) - case *expression.GetField: + default: if len(buffer) > 0 { return buffer[0], nil } return nil, nil - default: - return nil, ErrGroupBy.New(n.String()) } } diff --git a/sql/plan/group_by_test.go b/sql/plan/group_by_test.go index 06441106a..b1709315d 100644 --- a/sql/plan/group_by_test.go +++ b/sql/plan/group_by_test.go @@ -10,7 +10,7 @@ import ( "gopkg.in/src-d/go-mysql-server.v0/sql/expression/function/aggregation" ) -func TestGroupBy_Schema(t *testing.T) { +func TestGroupBySchema(t *testing.T) { require := require.New(t) child := mem.NewTable("test", nil) @@ -25,7 +25,7 @@ func TestGroupBy_Schema(t *testing.T) { }, gb.Schema()) } -func TestGroupBy_Resolved(t *testing.T) { +func TestGroupByResolved(t *testing.T) { require := require.New(t) child := mem.NewTable("test", nil) @@ -42,7 +42,7 @@ func TestGroupBy_Resolved(t *testing.T) { require.False(gb.Resolved()) } -func TestGroupBy_RowIter(t *testing.T) { +func TestGroupByRowIter(t *testing.T) { require := require.New(t) ctx := sql.NewEmptyContext() @@ -96,7 +96,7 @@ func TestGroupBy_RowIter(t *testing.T) { require.Equal(sql.NewRow("col1_2", int64(4444)), rows[1]) } -func TestGroupBy_EvalEmptyBuffer(t *testing.T) { +func TestGroupByEvalEmptyBuffer(t *testing.T) { require := require.New(t) ctx := sql.NewEmptyContext() @@ -105,7 +105,7 @@ func TestGroupBy_EvalEmptyBuffer(t *testing.T) { require.Nil(r) } -func TestGroupBy_Error(t *testing.T) { +func TestGroupByAggregationGrouping(t *testing.T) { require := require.New(t) ctx := sql.NewEmptyContext() @@ -140,8 +140,15 @@ func TestGroupBy_Error(t *testing.T) { NewResolvedTable(child), ) - _, err := sql.NodeToRows(ctx, p) - require.Error(err) + rows, err := sql.NodeToRows(ctx, p) + require.NoError(err) + + expected := []sql.Row{ + {int32(3), false}, + {int32(2), false}, + } + + require.Equal(expected, rows) } func BenchmarkGroupBy(b *testing.B) {