Skip to content
This repository has been archived by the owner on Jan 28, 2021. It is now read-only.

sql: implement count distinct #785

Merged
merged 1 commit into from
Jul 8, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ jobs:
- make TEST=ruby integration

- language: java
jdk: openjdk10
jdk: openjdk8
before_install:
- eval "$(gimme 1.12.4)"
install:
Expand Down Expand Up @@ -102,4 +102,4 @@ jobs:
install:
- go get ./...
script:
- make TEST=c integration
- make TEST=c integration
2 changes: 1 addition & 1 deletion SUPPORTED.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

## Grouping expressions
- AVG
- COUNT
- COUNT and COUNT(DISTINCT)
- MAX
- MIN
- SUM (always returns DOUBLE)
Expand Down
4 changes: 4 additions & 0 deletions engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1291,6 +1291,10 @@ var queries = []struct {
`SELECT LAST(i) FROM (SELECT i FROM mytable ORDER BY i) t`,
[]sql.Row{{int64(3)}},
},
{
`SELECT COUNT(DISTINCT t.i) FROM tabletest t, mytable t2`,
[]sql.Row{{int64(3)}},
},
}

func TestQueries(t *testing.T) {
Expand Down
5 changes: 5 additions & 0 deletions sql/analyzer/resolve_having.go
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,11 @@ func aggregationEquals(a, b sql.Expression) bool {
// the same.
_, ok := b.(*aggregation.Count)
return ok
case *aggregation.CountDistinct:
// it doesn't matter what's inside a Count, the result will be
// the same.
_, ok := b.(*aggregation.CountDistinct)
return ok
case *aggregation.Sum:
b, ok := b.(*aggregation.Sum)
if !ok {
Expand Down
91 changes: 91 additions & 0 deletions sql/expression/function/aggregation/count.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package aggregation
import (
"fmt"

"github.com/mitchellh/hashstructure"
"github.com/src-d/go-mysql-server/sql"
"github.com/src-d/go-mysql-server/sql/expression"
)
Expand Down Expand Up @@ -87,3 +88,93 @@ func (c *Count) Eval(ctx *sql.Context, buffer sql.Row) (interface{}, error) {
count := buffer[0]
return count, nil
}

// CountDistinct node to count how many distinct values its child expression
// yields across the aggregated rows. Values are tracked by their hash in the
// aggregation buffer.
type CountDistinct struct {
	expression.UnaryExpression
}

// NewCountDistinct builds a CountDistinct node whose only child is the given
// expression.
func NewCountDistinct(e sql.Expression) *CountDistinct {
	node := new(CountDistinct)
	node.Child = e
	return node
}

// NewBuffer returns a fresh aggregation buffer: a single-column row holding
// an empty set of value hashes.
func (c *CountDistinct) NewBuffer() sql.Row {
	seen := map[uint64]struct{}{}
	return sql.NewRow(seen)
}

// Type reports the SQL type of the aggregation result, which is always Int64.
func (c *CountDistinct) Type() sql.Type { return sql.Int64 }

// IsNullable reports whether the result may be null; it never is for this
// aggregation.
func (c *CountDistinct) IsNullable() bool { return false }

// Resolved implements the Expression interface. A star child is always
// considered resolved; any other child decides for itself.
func (c *CountDistinct) Resolved() bool {
	_, isStar := c.Child.(*expression.Star)
	return isStar || c.Child.Resolved()
}

// String renders the node in SQL-like form, wrapping the child's textual
// representation in COUNT(DISTINCT ...).
func (c *CountDistinct) String() string {
	const layout = "COUNT(DISTINCT %s)"
	return fmt.Sprintf(layout, c.Child)
}

// WithChildren implements the Expression interface. It requires exactly one
// child and returns a new CountDistinct wrapping it; any other count yields
// sql.ErrInvalidChildrenNumber.
func (c *CountDistinct) WithChildren(children ...sql.Expression) (sql.Expression, error) {
	if n := len(children); n != 1 {
		return nil, sql.ErrInvalidChildrenNumber.New(c, n, 1)
	}

	return NewCountDistinct(children[0]), nil
}

// Update implements the Aggregation interface. It records the value produced
// by the child for the given row in the buffer's set of seen hashes.
//
// With a star child the whole row is used as the value. Otherwise the child
// is evaluated against the row: an evaluation error is returned to the
// caller, and a nil (NULL) result is skipped without touching the buffer.
// An error is also returned when the value cannot be hashed.
func (c *CountDistinct) Update(ctx *sql.Context, buffer, row sql.Row) error {
	seen := buffer[0].(map[uint64]struct{})

	var value interface{}
	if _, ok := c.Child.(*expression.Star); ok {
		value = row
	} else {
		v, err := c.Child.Eval(ctx, row)
		// The error must be checked before the nil test: a failed evaluation
		// usually returns (nil, err), and testing for nil first would swallow
		// the error and treat the row as a NULL.
		if err != nil {
			return err
		}

		if v == nil {
			return nil
		}

		value = v
	}

	hash, err := hashstructure.Hash(value, nil)
	if err != nil {
		return fmt.Errorf("count distinct unable to hash value: %s", err)
	}

	seen[hash] = struct{}{}

	return nil
}

// Merge implements the Aggregation interface. Every hash recorded in the
// partial buffer is added to the destination buffer's set.
func (c *CountDistinct) Merge(ctx *sql.Context, buffer, partial sql.Row) error {
	dst := buffer[0].(map[uint64]struct{})
	src := partial[0].(map[uint64]struct{})
	for hash := range src {
		dst[hash] = struct{}{}
	}

	return nil
}

// Eval implements the Aggregation interface. The result is the number of
// hashes stored in the buffer, as an int64.
func (c *CountDistinct) Eval(ctx *sql.Context, buffer sql.Row) (interface{}, error) {
	distinct := len(buffer[0].(map[uint64]struct{}))
	return int64(distinct), nil
}
92 changes: 71 additions & 21 deletions sql/expression/function/aggregation/count_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,12 @@ package aggregation
import (
"testing"

"github.com/stretchr/testify/require"
"github.com/src-d/go-mysql-server/sql"
"github.com/src-d/go-mysql-server/sql/expression"
"github.com/stretchr/testify/require"
)

func TestCount_String(t *testing.T) {
require := require.New(t)

c := NewCount(expression.NewLiteral("foo", sql.Text))
require.Equal(`COUNT("foo")`, c.String())
}

func TestCount_Eval_1(t *testing.T) {
func TestCountEval1(t *testing.T) {
require := require.New(t)
ctx := sql.NewEmptyContext()

Expand All @@ -37,39 +30,96 @@ func TestCount_Eval_1(t *testing.T) {
require.Equal(int64(7), eval(t, c, b))
}

func TestCount_Eval_Star(t *testing.T) {
func TestCountEvalStar(t *testing.T) {
require := require.New(t)
ctx := sql.NewEmptyContext()

c := NewCount(expression.NewStar())
b := c.NewBuffer()
require.Equal(int64(0), eval(t, c, b))

c.Update(ctx, b, nil)
c.Update(ctx, b, sql.NewRow("foo"))
c.Update(ctx, b, sql.NewRow(1))
c.Update(ctx, b, sql.NewRow(nil))
c.Update(ctx, b, sql.NewRow(1, 2, 3))
require.NoError(c.Update(ctx, b, nil))
require.NoError(c.Update(ctx, b, sql.NewRow("foo")))
require.NoError(c.Update(ctx, b, sql.NewRow(1)))
require.NoError(c.Update(ctx, b, sql.NewRow(nil)))
require.NoError(c.Update(ctx, b, sql.NewRow(1, 2, 3)))
require.Equal(int64(5), eval(t, c, b))

b2 := c.NewBuffer()
c.Update(ctx, b2, sql.NewRow())
c.Update(ctx, b2, sql.NewRow("foo"))
c.Merge(ctx, b, b2)
require.NoError(c.Update(ctx, b2, sql.NewRow()))
require.NoError(c.Update(ctx, b2, sql.NewRow("foo")))
require.NoError(c.Merge(ctx, b, b2))
require.Equal(int64(7), eval(t, c, b))
}

func TestCount_Eval_String(t *testing.T) {
func TestCountEvalString(t *testing.T) {
require := require.New(t)
ctx := sql.NewEmptyContext()

c := NewCount(expression.NewGetField(0, sql.Text, "", true))
b := c.NewBuffer()
require.Equal(int64(0), eval(t, c, b))

c.Update(ctx, b, sql.NewRow("foo"))
require.NoError(c.Update(ctx, b, sql.NewRow("foo")))
require.Equal(int64(1), eval(t, c, b))

c.Update(ctx, b, sql.NewRow(nil))
require.NoError(c.Update(ctx, b, sql.NewRow(nil)))
require.Equal(int64(1), eval(t, c, b))
}

// TestCountDistinctEval1 checks that a constant child yields a distinct count
// of one regardless of how many rows are fed in.
func TestCountDistinctEval1(t *testing.T) {
	require := require.New(t)
	ctx := sql.NewEmptyContext()

	c := NewCountDistinct(expression.NewLiteral(1, sql.Int32))
	b := c.NewBuffer()
	require.Equal(int64(0), eval(t, c, b))

	rows := []sql.Row{
		nil,
		sql.NewRow("foo"),
		sql.NewRow(1),
		sql.NewRow(nil),
		sql.NewRow(1, 2, 3),
	}
	for _, row := range rows {
		require.NoError(c.Update(ctx, b, row))
	}

	require.Equal(int64(1), eval(t, c, b))
}

// TestCountDistinctEvalStar checks COUNT(DISTINCT *): every different row
// counts once, and merging a second buffer only adds rows not already seen.
func TestCountDistinctEvalStar(t *testing.T) {
	require := require.New(t)
	ctx := sql.NewEmptyContext()

	c := NewCountDistinct(expression.NewStar())
	b := c.NewBuffer()
	require.Equal(int64(0), eval(t, c, b))

	first := []sql.Row{
		nil,
		sql.NewRow("foo"),
		sql.NewRow(1),
		sql.NewRow(nil),
		sql.NewRow(1, 2, 3),
	}
	for _, row := range first {
		require.NoError(c.Update(ctx, b, row))
	}
	require.Equal(int64(5), eval(t, c, b))

	// Only the row (5) is new relative to the first buffer.
	b2 := c.NewBuffer()
	second := []sql.Row{
		sql.NewRow(1),
		sql.NewRow("foo"),
		sql.NewRow(5),
	}
	for _, row := range second {
		require.NoError(c.Update(ctx, b2, row))
	}
	require.NoError(c.Merge(ctx, b, b2))

	require.Equal(int64(6), eval(t, c, b))
}

// TestCountDistinctEvalString checks distinct counting over a text column:
// nil values are ignored and repeated strings are counted once.
func TestCountDistinctEvalString(t *testing.T) {
	require := require.New(t)
	ctx := sql.NewEmptyContext()

	c := NewCountDistinct(expression.NewGetField(0, sql.Text, "", true))
	b := c.NewBuffer()
	require.Equal(int64(0), eval(t, c, b))

	require.NoError(c.Update(ctx, b, sql.NewRow("foo")))
	require.Equal(int64(1), eval(t, c, b))

	for _, row := range []sql.Row{
		sql.NewRow(nil),
		sql.NewRow("foo"),
		sql.NewRow("bar"),
	} {
		require.NoError(c.Update(ctx, b, row))
	}
	require.Equal(int64(2), eval(t, c, b))
}
19 changes: 15 additions & 4 deletions sql/parse/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/src-d/go-mysql-server/sql"
"github.com/src-d/go-mysql-server/sql/expression"
"github.com/src-d/go-mysql-server/sql/expression/function"
"github.com/src-d/go-mysql-server/sql/expression/function/aggregation"
"github.com/src-d/go-mysql-server/sql/plan"
"gopkg.in/src-d/go-errors.v1"
"vitess.io/vitess/go/vt/sqlparser"
Expand Down Expand Up @@ -659,9 +660,11 @@ func getInt64Value(ctx *sql.Context, expr sqlparser.Expr, errStr string) (int64,
func isAggregate(e sql.Expression) bool {
var isAgg bool
expression.Inspect(e, func(e sql.Expression) bool {
fn, ok := e.(*expression.UnresolvedFunction)
if ok {
isAgg = isAgg || fn.IsAggregate
switch e := e.(type) {
case *expression.UnresolvedFunction:
isAgg = isAgg || e.IsAggregate
case *aggregation.CountDistinct:
isAgg = true
}

return true
Expand Down Expand Up @@ -789,7 +792,15 @@ func exprToExpression(e sqlparser.Expr) (sql.Expression, error) {
}

if v.Distinct {
return nil, ErrUnsupportedSyntax.New("DISTINCT on aggregations")
if v.Name.Lowered() != "count" {
return nil, ErrUnsupportedSyntax.New("DISTINCT on non-COUNT aggregations")
}

if len(exprs) != 1 {
return nil, ErrUnsupportedSyntax.New("more than one expression in COUNT")
}

return aggregation.NewCountDistinct(exprs[0]), nil
}

return expression.NewUnresolvedFunction(v.Name.Lowered(),
Expand Down
10 changes: 9 additions & 1 deletion sql/parse/parse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"testing"

"github.com/src-d/go-mysql-server/sql/expression"
"github.com/src-d/go-mysql-server/sql/expression/function/aggregation"
"github.com/src-d/go-mysql-server/sql/plan"
"gopkg.in/src-d/go-errors.v1"

Expand Down Expand Up @@ -1158,6 +1159,13 @@ var fixtures = map[string]sql.Node{
[]sql.Expression{},
plan.NewUnresolvedTable("foo", ""),
),
`SELECT COUNT(DISTINCT i) FROM foo`: plan.NewGroupBy(
[]sql.Expression{
aggregation.NewCountDistinct(expression.NewUnresolvedColumn("i")),
},
[]sql.Expression{},
plan.NewUnresolvedTable("foo", ""),
),
}

func TestParse(t *testing.T) {
Expand Down Expand Up @@ -1191,7 +1199,7 @@ var fixturesErrors = map[string]*errors.Kind{
`SELECT '2018-05-01' / INTERVAL 1 DAY`: ErrUnsupportedSyntax,
`SELECT INTERVAL 1 DAY + INTERVAL 1 DAY`: ErrUnsupportedSyntax,
`SELECT '2018-05-01' + (INTERVAL 1 DAY + INTERVAL 1 DAY)`: ErrUnsupportedSyntax,
`SELECT COUNT(DISTINCT foo) FROM b`: ErrUnsupportedSyntax,
`SELECT AVG(DISTINCT foo) FROM b`: ErrUnsupportedSyntax,
}

func TestParseErrors(t *testing.T) {
Expand Down