Skip to content
This repository has been archived by the owner on Jan 28, 2021. It is now read-only.

Commit

Permalink
sql: implement count distinct
Browse files Browse the repository at this point in the history
Signed-off-by: Miguel Molina <miguel@erizocosmi.co>
  • Loading branch information
erizocosmico committed Jul 5, 2019
1 parent 7f8224b commit 166881a
Show file tree
Hide file tree
Showing 7 changed files with 196 additions and 27 deletions.
2 changes: 1 addition & 1 deletion SUPPORTED.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

## Grouping expressions
- AVG
- COUNT
- COUNT and COUNT(DISTINCT)
- MAX
- MIN
- SUM (always returns DOUBLE)
Expand Down
4 changes: 4 additions & 0 deletions engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1291,6 +1291,10 @@ var queries = []struct {
`SELECT LAST(i) FROM (SELECT i FROM mytable ORDER BY i) t`,
[]sql.Row{{int64(3)}},
},
{
`SELECT COUNT(DISTINCT t.i) FROM tabletest t, mytable t2`,
[]sql.Row{{int64(3)}},
},
}

func TestQueries(t *testing.T) {
Expand Down
5 changes: 5 additions & 0 deletions sql/analyzer/resolve_having.go
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,11 @@ func aggregationEquals(a, b sql.Expression) bool {
// the same.
_, ok := b.(*aggregation.Count)
return ok
case *aggregation.CountDistinct:
// it doesn't matter what's inside a Count, the result will be
// the same.
_, ok := b.(*aggregation.CountDistinct)
return ok
case *aggregation.Sum:
b, ok := b.(*aggregation.Sum)
if !ok {
Expand Down
91 changes: 91 additions & 0 deletions sql/expression/function/aggregation/count.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package aggregation
import (
"fmt"

"github.com/mitchellh/hashstructure"
"github.com/src-d/go-mysql-server/sql"
"github.com/src-d/go-mysql-server/sql/expression"
)
Expand Down Expand Up @@ -87,3 +88,93 @@ func (c *Count) Eval(ctx *sql.Context, buffer sql.Row) (interface{}, error) {
count := buffer[0]
return count, nil
}

// CountDistinct node to count how many rows are in the result set.
type CountDistinct struct {
expression.UnaryExpression
}

// NewCountDistinct creates a new CountDistinct node.
func NewCountDistinct(e sql.Expression) *CountDistinct {
return &CountDistinct{expression.UnaryExpression{Child: e}}
}

// NewBuffer creates a new buffer for the aggregation.
func (c *CountDistinct) NewBuffer() sql.Row {
return sql.NewRow(make(map[uint64]struct{}))
}

// Type returns the type of the result.
func (c *CountDistinct) Type() sql.Type {
return sql.Int64
}

// IsNullable returns whether the return value can be null.
func (c *CountDistinct) IsNullable() bool {
return false
}

// Resolved implements the Expression interface.
func (c *CountDistinct) Resolved() bool {
if _, ok := c.Child.(*expression.Star); ok {
return true
}

return c.Child.Resolved()
}

func (c *CountDistinct) String() string {
return fmt.Sprintf("COUNT(DISTINCT %s)", c.Child)
}

// WithChildren implements the Expression interface.
func (c *CountDistinct) WithChildren(children ...sql.Expression) (sql.Expression, error) {
if len(children) != 1 {
return nil, sql.ErrInvalidChildrenNumber.New(c, len(children), 1)
}
return NewCountDistinct(children[0]), nil
}

// Update implements the Aggregation interface.
func (c *CountDistinct) Update(ctx *sql.Context, buffer, row sql.Row) error {
seen := buffer[0].(map[uint64]struct{})
var value interface{}
if _, ok := c.Child.(*expression.Star); ok {
value = row
} else {
v, err := c.Child.Eval(ctx, row)
if v == nil {
return nil
}

if err != nil {
return err
}

value = v
}

hash, err := hashstructure.Hash(value, nil)
if err != nil {
return fmt.Errorf("count distinct unable to hash value: %s", err)
}

seen[hash] = struct{}{}

return nil
}

// Merge implements the Aggregation interface.
func (c *CountDistinct) Merge(ctx *sql.Context, buffer, partial sql.Row) error {
seen := buffer[0].(map[uint64]struct{})
for k := range partial[0].(map[uint64]struct{}) {
seen[k] = struct{}{}
}
return nil
}

// Eval implements the Aggregation interface.
func (c *CountDistinct) Eval(ctx *sql.Context, buffer sql.Row) (interface{}, error) {
seen := buffer[0].(map[uint64]struct{})
return int64(len(seen)), nil
}
92 changes: 71 additions & 21 deletions sql/expression/function/aggregation/count_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,12 @@ package aggregation
import (
"testing"

"github.com/stretchr/testify/require"
"github.com/src-d/go-mysql-server/sql"
"github.com/src-d/go-mysql-server/sql/expression"
"github.com/stretchr/testify/require"
)

func TestCount_String(t *testing.T) {
require := require.New(t)

c := NewCount(expression.NewLiteral("foo", sql.Text))
require.Equal(`COUNT("foo")`, c.String())
}

func TestCount_Eval_1(t *testing.T) {
func TestCountEval1(t *testing.T) {
require := require.New(t)
ctx := sql.NewEmptyContext()

Expand All @@ -37,39 +30,96 @@ func TestCount_Eval_1(t *testing.T) {
require.Equal(int64(7), eval(t, c, b))
}

func TestCount_Eval_Star(t *testing.T) {
func TestCountEvalStar(t *testing.T) {
require := require.New(t)
ctx := sql.NewEmptyContext()

c := NewCount(expression.NewStar())
b := c.NewBuffer()
require.Equal(int64(0), eval(t, c, b))

c.Update(ctx, b, nil)
c.Update(ctx, b, sql.NewRow("foo"))
c.Update(ctx, b, sql.NewRow(1))
c.Update(ctx, b, sql.NewRow(nil))
c.Update(ctx, b, sql.NewRow(1, 2, 3))
require.NoError(c.Update(ctx, b, nil))
require.NoError(c.Update(ctx, b, sql.NewRow("foo")))
require.NoError(c.Update(ctx, b, sql.NewRow(1)))
require.NoError(c.Update(ctx, b, sql.NewRow(nil)))
require.NoError(c.Update(ctx, b, sql.NewRow(1, 2, 3)))
require.Equal(int64(5), eval(t, c, b))

b2 := c.NewBuffer()
c.Update(ctx, b2, sql.NewRow())
c.Update(ctx, b2, sql.NewRow("foo"))
c.Merge(ctx, b, b2)
require.NoError(c.Update(ctx, b2, sql.NewRow()))
require.NoError(c.Update(ctx, b2, sql.NewRow("foo")))
require.NoError(c.Merge(ctx, b, b2))
require.Equal(int64(7), eval(t, c, b))
}

func TestCount_Eval_String(t *testing.T) {
func TestCountEvalString(t *testing.T) {
require := require.New(t)
ctx := sql.NewEmptyContext()

c := NewCount(expression.NewGetField(0, sql.Text, "", true))
b := c.NewBuffer()
require.Equal(int64(0), eval(t, c, b))

c.Update(ctx, b, sql.NewRow("foo"))
require.NoError(c.Update(ctx, b, sql.NewRow("foo")))
require.Equal(int64(1), eval(t, c, b))

c.Update(ctx, b, sql.NewRow(nil))
require.NoError(c.Update(ctx, b, sql.NewRow(nil)))
require.Equal(int64(1), eval(t, c, b))
}

func TestCountDistinctEval1(t *testing.T) {
require := require.New(t)
ctx := sql.NewEmptyContext()

c := NewCountDistinct(expression.NewLiteral(1, sql.Int32))
b := c.NewBuffer()
require.Equal(int64(0), eval(t, c, b))

require.NoError(c.Update(ctx, b, nil))
require.NoError(c.Update(ctx, b, sql.NewRow("foo")))
require.NoError(c.Update(ctx, b, sql.NewRow(1)))
require.NoError(c.Update(ctx, b, sql.NewRow(nil)))
require.NoError(c.Update(ctx, b, sql.NewRow(1, 2, 3)))
require.Equal(int64(1), eval(t, c, b))
}

func TestCountDistinctEvalStar(t *testing.T) {
require := require.New(t)
ctx := sql.NewEmptyContext()

c := NewCountDistinct(expression.NewStar())
b := c.NewBuffer()
require.Equal(int64(0), eval(t, c, b))

require.NoError(c.Update(ctx, b, nil))
require.NoError(c.Update(ctx, b, sql.NewRow("foo")))
require.NoError(c.Update(ctx, b, sql.NewRow(1)))
require.NoError(c.Update(ctx, b, sql.NewRow(nil)))
require.NoError(c.Update(ctx, b, sql.NewRow(1, 2, 3)))
require.Equal(int64(5), eval(t, c, b))

b2 := c.NewBuffer()
require.NoError(c.Update(ctx, b2, sql.NewRow(1)))
require.NoError(c.Update(ctx, b2, sql.NewRow("foo")))
require.NoError(c.Update(ctx, b2, sql.NewRow(5)))
require.NoError(c.Merge(ctx, b, b2))

require.Equal(int64(6), eval(t, c, b))
}

func TestCountDistinctEvalString(t *testing.T) {
require := require.New(t)
ctx := sql.NewEmptyContext()

c := NewCountDistinct(expression.NewGetField(0, sql.Text, "", true))
b := c.NewBuffer()
require.Equal(int64(0), eval(t, c, b))

require.NoError(c.Update(ctx, b, sql.NewRow("foo")))
require.Equal(int64(1), eval(t, c, b))

require.NoError(c.Update(ctx, b, sql.NewRow(nil)))
require.NoError(c.Update(ctx, b, sql.NewRow("foo")))
require.NoError(c.Update(ctx, b, sql.NewRow("bar")))
require.Equal(int64(2), eval(t, c, b))
}
19 changes: 15 additions & 4 deletions sql/parse/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/src-d/go-mysql-server/sql"
"github.com/src-d/go-mysql-server/sql/expression"
"github.com/src-d/go-mysql-server/sql/expression/function"
"github.com/src-d/go-mysql-server/sql/expression/function/aggregation"
"github.com/src-d/go-mysql-server/sql/plan"
"gopkg.in/src-d/go-errors.v1"
"vitess.io/vitess/go/vt/sqlparser"
Expand Down Expand Up @@ -659,9 +660,11 @@ func getInt64Value(ctx *sql.Context, expr sqlparser.Expr, errStr string) (int64,
func isAggregate(e sql.Expression) bool {
var isAgg bool
expression.Inspect(e, func(e sql.Expression) bool {
fn, ok := e.(*expression.UnresolvedFunction)
if ok {
isAgg = isAgg || fn.IsAggregate
switch e := e.(type) {
case *expression.UnresolvedFunction:
isAgg = isAgg || e.IsAggregate
case *aggregation.CountDistinct:
isAgg = true
}

return true
Expand Down Expand Up @@ -789,7 +792,15 @@ func exprToExpression(e sqlparser.Expr) (sql.Expression, error) {
}

if v.Distinct {
return nil, ErrUnsupportedSyntax.New("DISTINCT on aggregations")
if v.Name.Lowered() != "count" {
return nil, ErrUnsupportedSyntax.New("DISTINCT on non-COUNT aggregations")
}

if len(exprs) != 1 {
return nil, ErrUnsupportedSyntax.New("more than one expression in COUNT")
}

return aggregation.NewCountDistinct(exprs[0]), nil
}

return expression.NewUnresolvedFunction(v.Name.Lowered(),
Expand Down
10 changes: 9 additions & 1 deletion sql/parse/parse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"testing"

"github.com/src-d/go-mysql-server/sql/expression"
"github.com/src-d/go-mysql-server/sql/expression/function/aggregation"
"github.com/src-d/go-mysql-server/sql/plan"
"gopkg.in/src-d/go-errors.v1"

Expand Down Expand Up @@ -1158,6 +1159,13 @@ var fixtures = map[string]sql.Node{
[]sql.Expression{},
plan.NewUnresolvedTable("foo", ""),
),
`SELECT COUNT(DISTINCT i) FROM foo`: plan.NewGroupBy(
[]sql.Expression{
aggregation.NewCountDistinct(expression.NewUnresolvedColumn("i")),
},
[]sql.Expression{},
plan.NewUnresolvedTable("foo", ""),
),
}

func TestParse(t *testing.T) {
Expand Down Expand Up @@ -1191,7 +1199,7 @@ var fixturesErrors = map[string]*errors.Kind{
`SELECT '2018-05-01' / INTERVAL 1 DAY`: ErrUnsupportedSyntax,
`SELECT INTERVAL 1 DAY + INTERVAL 1 DAY`: ErrUnsupportedSyntax,
`SELECT '2018-05-01' + (INTERVAL 1 DAY + INTERVAL 1 DAY)`: ErrUnsupportedSyntax,
`SELECT COUNT(DISTINCT foo) FROM b`: ErrUnsupportedSyntax,
`SELECT AVG(DISTINCT foo) FROM b`: ErrUnsupportedSyntax,
}

func TestParseErrors(t *testing.T) {
Expand Down

0 comments on commit 166881a

Please sign in to comment.