diff --git a/engine_test.go b/engine_test.go index f8158d616..5e9b9964b 100644 --- a/engine_test.go +++ b/engine_test.go @@ -201,6 +201,32 @@ var queries = []struct { {int64(1), "third row"}, }, }, + { + `SELECT fi, COUNT(*) FROM ( + SELECT tbl.s AS fi + FROM mytable tbl + ) t + GROUP BY fi + ORDER BY COUNT(*) ASC`, + []sql.Row{ + {"first row", int64(1)}, + {"second row", int64(1)}, + {"third row", int64(1)}, + }, + }, + { + `SELECT COUNT(*), fi FROM ( + SELECT tbl.s AS fi + FROM mytable tbl + ) t + GROUP BY fi + ORDER BY COUNT(*) ASC`, + []sql.Row{ + {int64(1), "first row"}, + {int64(1), "second row"}, + {int64(1), "third row"}, + }, + }, { `SELECT COUNT(*) as cnt, fi FROM ( SELECT tbl.s AS fi @@ -367,7 +393,7 @@ var queries = []struct { }, }, { - `SELECT COUNT(*) c, i as foo FROM mytable GROUP BY i ORDER BY foo, i DESC`, + `SELECT COUNT(*) c, i as foo FROM mytable GROUP BY i ORDER BY foo DESC`, []sql.Row{ {int64(1), int64(3)}, {int64(1), int64(2)}, @@ -375,7 +401,7 @@ var queries = []struct { }, }, { - `SELECT COUNT(*) c, i as foo FROM mytable GROUP BY 2 ORDER BY foo, i DESC`, + `SELECT COUNT(*) c, i as foo FROM mytable GROUP BY 2 ORDER BY foo DESC`, []sql.Row{ {int64(1), int64(3)}, {int64(1), int64(2)}, @@ -442,6 +468,22 @@ var queries = []struct { {float64(4), int64(3)}, }, }, + { + "SELECT SUM(i), i FROM mytable GROUP BY i ORDER BY 1+SUM(i) ASC", + []sql.Row{ + {float64(1), int64(1)}, + {float64(2), int64(2)}, + {float64(3), int64(3)}, + }, + }, + { + "SELECT i, SUM(i) FROM mytable GROUP BY i ORDER BY SUM(i) DESC", + []sql.Row{ + {int64(3), float64(3)}, + {int64(2), float64(2)}, + {int64(1), float64(1)}, + }, + }, { `/*!40101 SET NAMES utf8 */`, []sql.Row{}, @@ -978,19 +1020,33 @@ var queries = []struct { `, []sql.Row{{"s"}, {"s2"}}, }, + { + "SELECT s, i FROM mytable GROUP BY i ORDER BY SUBSTRING(s, 1, 1) DESC", + []sql.Row{ + {string("third row"), int64(3)}, + {string("second row"), int64(2)}, + {string("first row"), int64(1)}, + }, + }, + { + "SELECT s, i FROM mytable GROUP BY i HAVING count(*) > 0 ORDER BY SUBSTRING(s, 1, 1) DESC", + []sql.Row{ + {string("third row"), int64(3)}, + {string("second row"), int64(2)}, + {string("first row"), int64(1)}, + }, + }, } func TestQueries(t *testing.T) { e := newEngine(t) - - ep := newEngineWithParallelism(t, 2) - t.Run("sequential", func(t *testing.T) { for _, tt := range queries { testQuery(t, e, tt.query, tt.expected) } }) + ep := newEngineWithParallelism(t, 2) t.Run("parallel", func(t *testing.T) { for _, tt := range queries { testQuery(t, ep, tt.query, tt.expected) @@ -1589,6 +1645,8 @@ func testQuery(t *testing.T, e *sqle.Engine, q string, expected []sql.Row) { } func testQueryWithContext(ctx *sql.Context, t *testing.T, e *sqle.Engine, q string, expected []sql.Row) { + orderBy := strings.Contains(strings.ToUpper(q), " ORDER BY ") + t.Run(q, func(t *testing.T) { require := require.New(t) _, iter, err := e.Query(ctx, q) @@ -1597,7 +1655,11 @@ func testQueryWithContext(ctx *sql.Context, t *testing.T, e *sqle.Engine, q stri rows, err := sql.RowIterToRows(iter) require.NoError(err) - require.ElementsMatch(expected, rows) + if orderBy { + require.Equal(expected, rows) + } else { + require.ElementsMatch(expected, rows) + } }) } diff --git a/sql/analyzer/resolve_orderby.go b/sql/analyzer/resolve_orderby.go index 111a1a2bd..bdea67bed 100644 --- a/sql/analyzer/resolve_orderby.go +++ b/sql/analyzer/resolve_orderby.go @@ -162,6 +162,8 @@ func pushSortDown(sort *plan.Sort) (sql.Node, error) { child.Grouping, plan.NewSort(sort.SortFields, child.Child), ), nil + case *plan.ResolvedTable: + return child, nil default: // Can't do anything here, there should be either a project or a groupby // below an order by. @@ -183,7 +185,14 @@ func resolveOrderByLiterals(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node return n, nil } - var fields = make([]plan.SortField, len(sort.SortFields)) + schema := sort.Child.Schema() + var ( + fields = make([]plan.SortField, len(sort.SortFields)) + schemaCols = make([]string, len(schema)) + ) + for i, col := range sort.Child.Schema() { + schemaCols[i] = col.Name + } for i, f := range sort.SortFields { if lit, ok := f.Column.(*expression.Literal); ok && sql.IsNumber(f.Column.Type()) { // it is safe to eval literals with no context and/or row @@ -199,21 +208,32 @@ func resolveOrderByLiterals(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node // column access is 1-indexed idx := int(v.(int64)) - 1 - - schema := sort.Child.Schema() if idx >= len(schema) || idx < 0 { return nil, ErrOrderByColumnIndex.New(idx + 1) } fields[i] = plan.SortField{ - Column: expression.NewUnresolvedColumn(schema[idx].Name), + Column: expression.NewUnresolvedColumn(schemaCols[idx]), Order: f.Order, NullOrdering: f.NullOrdering, } - a.Log("replaced order by column %d with %s", idx+1, schema[idx].Name) + a.Log("replaced order by column %d with %s", idx+1, schemaCols[idx]) } else { - fields[i] = f + if agg, ok := f.Column.(sql.Aggregation); ok { + name := agg.String() + if nameable, ok := f.Column.(sql.Nameable); ok { + name = nameable.Name() + } + + fields[i] = plan.SortField{ + Column: expression.NewUnresolvedColumn(name), + Order: f.Order, + NullOrdering: f.NullOrdering, + } + } else { + fields[i] = f + } } }