Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add group_concat aggregation support #13331

Merged
merged 9 commits into from
Jun 22, 2023
1 change: 1 addition & 0 deletions go/sqltypes/testing.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ func MakeTestResult(fields []*querypb.Field, rows ...string) *Result {
result.Rows[i] = make([]Value, len(fields))
for j, col := range split(row) {
if col == "null" {
result.Rows[i][j] = NULL
continue
}
result.Rows[i][j] = MakeTrusted(fields[j].Type, []byte(col))
Expand Down
7 changes: 4 additions & 3 deletions go/test/endtoend/utils/cmp.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ func (mcmp *MySQLCompare) AssertMatchesAny(query string, expected ...string) {
func (mcmp *MySQLCompare) AssertMatchesAnyNoCompare(query string, expected ...string) {
mcmp.t.Helper()

mQr, vQr := mcmp.execNoCompare(query)
mQr, vQr := mcmp.ExecNoCompare(query)
got := fmt.Sprintf("%v", mQr.Rows)
valid := false
for _, e := range expected {
Expand Down Expand Up @@ -171,7 +171,7 @@ func (mcmp *MySQLCompare) AssertFoundRowsValue(query, workload string, count int
// AssertMatchesNoCompare compares the record of mysql and vitess separately and not with each other.
func (mcmp *MySQLCompare) AssertMatchesNoCompare(query, mExp string, vExp string) {
mcmp.t.Helper()
mQr, vQr := mcmp.execNoCompare(query)
mQr, vQr := mcmp.ExecNoCompare(query)
got := fmt.Sprintf("%v", mQr.Rows)
diff := cmp.Diff(mExp, got)
if diff != "" {
Expand Down Expand Up @@ -200,7 +200,8 @@ func (mcmp *MySQLCompare) Exec(query string) *sqltypes.Result {
return vtQr
}

func (mcmp *MySQLCompare) execNoCompare(query string) (*sqltypes.Result, *sqltypes.Result) {
// ExecNoCompare executes the query on vitess and mysql but does not compare the result with each other.
func (mcmp *MySQLCompare) ExecNoCompare(query string) (*sqltypes.Result, *sqltypes.Result) {
mcmp.t.Helper()
vtQr, err := mcmp.VtConn.ExecuteFetch(query, 1000, true)
require.NoError(mcmp.t, err, "[Vitess Error] for query: "+query)
Expand Down
52 changes: 52 additions & 0 deletions go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,14 @@ package aggregation

import (
"fmt"
"sort"
"strings"
"testing"

"github.com/stretchr/testify/require"
"golang.org/x/exp/slices"

"vitess.io/vitess/go/sqltypes"
"vitess.io/vitess/go/test/endtoend/cluster"
"vitess.io/vitess/go/test/endtoend/utils"
)
Expand Down Expand Up @@ -475,3 +479,51 @@ func TestComplexAggregation(t *testing.T) {
mcmp.Exec(`SELECT /*vt+ PLANNER=gen4 */ name+COUNT(t1_id)+1 FROM t1 GROUP BY name`)
mcmp.Exec(`SELECT /*vt+ PLANNER=gen4 */ COUNT(*)+shardkey+MIN(t1_id)+1+MAX(t1_id)*SUM(t1_id)+1+name FROM t1 GROUP BY shardkey, name`)
}

// TestGroupConcatAggregation runs a set of group_concat queries against both MySQL and
// Vitess. The two results are not compared verbatim; each pair is handed to compareRow,
// which checks the aggregated column (offset 0) per grouping key.
func TestGroupConcatAggregation(t *testing.T) {
	mcmp, closer := start(t)
	defer closer()

	// Seed data: t1 contains NULL values and repeated names, t2 provides the join side.
	mcmp.Exec("insert into t1(t1_id, `name`, `value`, shardkey) values(1,'a1',null,100), (2,'b1','foo',20), (3,'c1','foo',10), (4,'a1','foo',100), (5,'d1','toto',200), (6,'c1',null,893), (10,'a1','titi',2380), (20,'b1','tete',12833), (9,'e1','yoyo',783493)")
	mcmp.Exec("insert into t2(id, shardKey) values (1, 10), (2, 20)")

	// groupBy holds the column offsets that identify a group; it stays nil for
	// queries without a GROUP BY clause.
	cases := []struct {
		query   string
		groupBy []int
	}{
		{query: `SELECT /*vt+ PLANNER=gen4 */ group_concat(name) FROM t1`},
		{query: `SELECT /*vt+ PLANNER=gen4 */ group_concat(value) FROM t1 join t2 on t1.shardKey = t2.shardKey `},
		{query: `SELECT /*vt+ PLANNER=gen4 */ group_concat(value) FROM t1 join t2 on t1.t1_id = t2.shardKey `},
		{query: `SELECT /*vt+ PLANNER=gen4 */ group_concat(value) FROM t1 join t2 on t1.shardKey = t2.id `},
		{query: `SELECT /*vt+ PLANNER=gen4 */ group_concat(value), t1.name FROM t1, t2 group by t1.name`, groupBy: []int{1}},
	}
	for _, tc := range cases {
		mysqlRes, vitessRes := mcmp.ExecNoCompare(tc.query)
		compareRow(t, mysqlRes, vitessRes, tc.groupBy, []int{0})
	}
}

// compareRow asserts that the MySQL and Vitess results contain the same groups with the
// same aggregated values, ignoring the order of the comma-separated items inside each
// aggregated value.
//
// Rows of the two results are paired by a key built from the column offsets in grpCols;
// when grpCols is empty every row shares the same (empty) key. For each pair, every
// column offset in fCols is split on ",", sorted, and compared.
//
// The test fails when the row counts differ, when a Vitess row has no MySQL row with the
// same key, or when any compared value differs.
func compareRow(t *testing.T, mRes *sqltypes.Result, vtRes *sqltypes.Result, grpCols []int, fCols []int) {
	// Report failures at the caller's line rather than inside this helper.
	t.Helper()
	require.Equal(t, len(mRes.Rows), len(vtRes.Rows), "mysql and vitess result count does not match")

	// groupKey joins the grouping columns with a NUL separator so that adjacent
	// columns cannot run together into an ambiguous key.
	groupKey := func(row []sqltypes.Value) string {
		parts := make([]string, len(grpCols))
		for i, col := range grpCols {
			parts[i] = row[col].String()
		}
		return strings.Join(parts, "\x00")
	}

	// sortedItems splits a comma-separated aggregate value and sorts the pieces so
	// that two values can be compared independently of concatenation order.
	sortedItems := func(v sqltypes.Value) []string {
		items := strings.Split(v.ToString(), ",")
		sort.Strings(items)
		return items
	}

	for _, row := range vtRes.Rows {
		grpKey := groupKey(row)
		foundKey := false
		for _, mRow := range mRes.Rows {
			if groupKey(mRow) != grpKey {
				continue
			}
			foundKey = true
			for _, col := range fCols {
				vtItems := sortedItems(row[col])
				mItems := sortedItems(mRow[col])
				require.True(t, slices.Equal(vtItems, mItems), "mysql and vitess results are not the same: vitess:%v, mysql:%v", vtRes.Rows, mRes.Rows)
			}
		}
		require.True(t, foundKey, "mysql and vitess results do not have the same row: vitess:%v, mysql:%v", vtRes.Rows, mRes.Rows)
	}
}
15 changes: 15 additions & 0 deletions go/vt/vtgate/engine/opcode/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ const (
AggregateGtid
AggregateRandom
AggregateCountStar
AggregateGroupConcat
)

var (
Expand Down Expand Up @@ -95,6 +96,7 @@ var SupportedAggregates = map[string]AggregateOpcode{
"vgtid": AggregateGtid,
"count_star": AggregateCountStar,
"random": AggregateRandom,
"group_concat": AggregateGroupConcat,
}

func (code AggregateOpcode) String() string {
Expand All @@ -111,3 +113,16 @@ func (code AggregateOpcode) String() string {
// MarshalJSON encodes the opcode as its String() form wrapped in double quotes.
// The returned error is always nil.
func (code AggregateOpcode) MarshalJSON() ([]byte, error) {
	quoted := fmt.Sprintf("\"%s\"", code.String())
	return []byte(quoted), nil
}

// Type reports the SQL type produced by this aggregate opcode.
//
// For AggregateGroupConcat the result depends on the input field: a binary input
// yields sqltypes.Blob, anything else yields sqltypes.Text. Every other opcode is
// resolved through the OpcodeType table and does not consult field.
func (code AggregateOpcode) Type(field *querypb.Field) querypb.Type {
	if code != AggregateGroupConcat {
		return OpcodeType[code]
	}
	if sqltypes.IsBinary(field.Type) {
		return sqltypes.Blob
	}
	return sqltypes.Text
}
53 changes: 38 additions & 15 deletions go/vt/vtgate/engine/ordered_aggregate.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,12 @@ import (
"strconv"

"vitess.io/vitess/go/mysql/collations"

"vitess.io/vitess/go/vt/sqlparser"

"vitess.io/vitess/go/sqltypes"
. "vitess.io/vitess/go/vt/vtgate/engine/opcode"
"vitess.io/vitess/go/vt/vtgate/evalengine"

binlogdatapb "vitess.io/vitess/go/vt/proto/binlogdata"
querypb "vitess.io/vitess/go/vt/proto/query"
"vitess.io/vitess/go/vt/sqlparser"
. "vitess.io/vitess/go/vt/vtgate/engine/opcode"
"vitess.io/vitess/go/vt/vtgate/evalengine"
)

var (
Expand Down Expand Up @@ -122,7 +119,12 @@ func (ap *AggregateParams) isDistinct() bool {
}

// preProcess reports whether ap.Opcode is one of AggregateCountDistinct,
// AggregateSumDistinct, AggregateGtid or AggregateCount (per the one-line return),
// with the switch form additionally listing AggregateGroupConcat.
func (ap *AggregateParams) preProcess() bool {
// NOTE(review): as written, this return makes the switch below unreachable. The two
// forms look like the removed and added sides of the same diff hunk — confirm which
// one is live before relying on AggregateGroupConcat being included.
return ap.Opcode == AggregateCountDistinct || ap.Opcode == AggregateSumDistinct || ap.Opcode == AggregateGtid || ap.Opcode == AggregateCount
switch ap.Opcode {
case AggregateCountDistinct, AggregateSumDistinct, AggregateGtid, AggregateCount, AggregateGroupConcat:
return true
default:
return false
}
}

func (ap *AggregateParams) String() string {
Expand Down Expand Up @@ -177,16 +179,17 @@ func (oa *OrderedAggregate) execute(ctx context.Context, vcursor VCursor, bindVa
if err != nil {
return nil, err
}
fields := convertFields(result.Fields, oa.PreProcess, oa.Aggregates, oa.AggrOnEngine)
out := &sqltypes.Result{
Fields: convertFields(result.Fields, oa.PreProcess, oa.Aggregates, oa.AggrOnEngine),
Fields: fields,
Rows: make([][]sqltypes.Value, 0, len(result.Rows)),
}
// This code is similar to the one in StreamExecute.
var current []sqltypes.Value
var curDistincts []sqltypes.Value
for _, row := range result.Rows {
if current == nil {
current, curDistincts = convertRow(row, oa.PreProcess, oa.Aggregates, oa.AggrOnEngine)
current, curDistincts = convertRow(fields, row, oa.PreProcess, oa.Aggregates, oa.AggrOnEngine)
continue
}
equal, err := oa.keysEqual(current, row)
Expand All @@ -195,14 +198,14 @@ func (oa *OrderedAggregate) execute(ctx context.Context, vcursor VCursor, bindVa
}

if equal {
current, curDistincts, err = merge(result.Fields, current, row, curDistincts, oa.Aggregates)
current, curDistincts, err = merge(fields, current, row, curDistincts, oa.Aggregates)
if err != nil {
return nil, err
}
continue
}
out.Rows = append(out.Rows, current)
current, curDistincts = convertRow(row, oa.PreProcess, oa.Aggregates, oa.AggrOnEngine)
current, curDistincts = convertRow(fields, row, oa.PreProcess, oa.Aggregates, oa.AggrOnEngine)
}

if current != nil {
Expand Down Expand Up @@ -235,7 +238,7 @@ func (oa *OrderedAggregate) TryStreamExecute(ctx context.Context, vcursor VCurso
// This code is similar to the one in Execute.
for _, row := range qr.Rows {
if current == nil {
current, curDistincts = convertRow(row, oa.PreProcess, oa.Aggregates, oa.AggrOnEngine)
current, curDistincts = convertRow(fields, row, oa.PreProcess, oa.Aggregates, oa.AggrOnEngine)
continue
}

Expand All @@ -254,7 +257,7 @@ func (oa *OrderedAggregate) TryStreamExecute(ctx context.Context, vcursor VCurso
if err := cb(&sqltypes.Result{Rows: [][]sqltypes.Value{current}}); err != nil {
return err
}
current, curDistincts = convertRow(row, oa.PreProcess, oa.Aggregates, oa.AggrOnEngine)
current, curDistincts = convertRow(fields, row, oa.PreProcess, oa.Aggregates, oa.AggrOnEngine)
}
return nil
})
Expand All @@ -280,7 +283,7 @@ func convertFields(fields []*querypb.Field, preProcess bool, aggrs []*AggregateP
}
fields[aggr.Col] = &querypb.Field{
Name: aggr.Alias,
Type: OpcodeType[aggr.Opcode],
Type: aggr.Opcode.Type(fields[aggr.Col]),
}
if aggr.isDistinct() {
aggr.KeyCol = aggr.Col
Expand All @@ -289,7 +292,13 @@ func convertFields(fields []*querypb.Field, preProcess bool, aggrs []*AggregateP
return fields
}

func convertRow(row []sqltypes.Value, preProcess bool, aggregates []*AggregateParams, aggrOnEngine bool) (newRow []sqltypes.Value, curDistincts []sqltypes.Value) {
func convertRow(
fields []*querypb.Field,
row []sqltypes.Value,
preProcess bool,
aggregates []*AggregateParams,
aggrOnEngine bool,
) (newRow []sqltypes.Value, curDistincts []sqltypes.Value) {
if !preProcess {
return row, nil
}
Expand Down Expand Up @@ -342,6 +351,10 @@ func convertRow(row []sqltypes.Value, preProcess bool, aggregates []*AggregatePa
data, _ := vgtid.MarshalVT()
val, _ := sqltypes.NewValue(sqltypes.VarBinary, data)
newRow[aggr.Col] = val
case AggregateGroupConcat:
if !row[aggr.Col].IsNull() {
newRow[aggr.Col] = sqltypes.MakeTrusted(fields[aggr.Col].Type, []byte(row[aggr.Col].ToString()))
}
}
}
return newRow, curDistincts
Expand Down Expand Up @@ -466,6 +479,16 @@ func merge(
result[aggr.Col] = val
case AggregateRandom:
// we just grab the first value per grouping. no need to do anything more complicated here
case AggregateGroupConcat:
if row2[aggr.Col].IsNull() {
break
}
if result[aggr.Col].IsNull() {
result[aggr.Col] = sqltypes.MakeTrusted(fields[aggr.Col].Type, []byte(row2[aggr.Col].ToString()))
break
}
concat := row1[aggr.Col].ToString() + "," + row2[aggr.Col].ToString()
result[aggr.Col] = sqltypes.MakeTrusted(fields[aggr.Col].Type, []byte(concat))
default:
return nil, nil, fmt.Errorf("BUG: Unexpected opcode: %v", aggr.Opcode)
}
Expand Down
Loading
Loading