Skip to content

Commit

Permalink
expression: refactor grouping function computation and update related…
Browse files Browse the repository at this point in the history
… tipb (#44436)

close #44437
  • Loading branch information
AilinKid committed Jun 7, 2023
1 parent 8b67cea commit aedbcd0
Show file tree
Hide file tree
Showing 8 changed files with 106 additions and 63 deletions.
4 changes: 2 additions & 2 deletions DEPS.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -3472,8 +3472,8 @@ def go_deps():
name = "com_github_pingcap_tipb",
build_file_proto_mode = "disable_global",
importpath = "github.com/pingcap/tipb",
sum = "h1:f0c37nxxOl7C40+mC5bO9+IbVf8ia1frMU/WD0Heo4E=",
version = "v0.0.0-20230523034258-1bbc3bbbd369",
sum = "h1:J2HQyR5v1AcoBzx5/AYJW9XFSIl6si6YoC6yGI1W89c=",
version = "v0.0.0-20230602100112-acb7942db1ca",
)
go_repository(
name = "com_github_pkg_browser",
Expand Down
2 changes: 1 addition & 1 deletion executor/test/showtest/show_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1524,7 +1524,7 @@ func TestShowBuiltin(t *testing.T) {
res := tk.MustQuery("show builtins;")
require.NotNil(t, res)
rows := res.Rows()
const builtinFuncNum = 289
const builtinFuncNum = 290
require.Equal(t, builtinFuncNum, len(rows))
require.Equal(t, rows[0][0].(string), "abs")
require.Equal(t, rows[builtinFuncNum-1][0].(string), "yearweek")
Expand Down
1 change: 1 addition & 0 deletions expression/builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -800,6 +800,7 @@ var funcs = map[string]functionClass{
ast.BinToUUID: &binToUUIDFunctionClass{baseFunctionClass{ast.BinToUUID, 1, 2}},
ast.TiDBShard: &tidbShardFunctionClass{baseFunctionClass{ast.TiDBShard, 1, 1}},
ast.TiDBRowChecksum: &tidbRowChecksumFunctionClass{baseFunctionClass{ast.TiDBRowChecksum, 0, 0}},
ast.Grouping: &groupingImplFunctionClass{baseFunctionClass{ast.Grouping, 1, 1}},

ast.GetLock: &lockFunctionClass{baseFunctionClass{ast.GetLock, 2, 2}},
ast.ReleaseLock: &releaseLockFunctionClass{baseFunctionClass{ast.ReleaseLock, 1, 1}},
Expand Down
121 changes: 81 additions & 40 deletions expression/builtin_grouping.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package expression
import (
"github.com/gogo/protobuf/proto"
"github.com/pingcap/errors"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/chunk"
Expand Down Expand Up @@ -44,8 +45,9 @@ func (c *groupingImplFunctionClass) getFunction(ctx sessionctx.Context, args []E
if err != nil {
return nil, err
}
bf.tp.SetFlen(1)
sig := &BuiltinGroupingImplSig{bf, 0, map[uint64]struct{}{}, false}
// grouping(x,y,z) is a singed UInt64 (while MySQL is Int64 which is unreasonable)
bf.tp.SetFlag(bf.tp.GetFlag() | mysql.UnsignedFlag)
sig := &BuiltinGroupingImplSig{bf, 0, []map[uint64]struct{}{}, false}
sig.setPbCode(tipb.ScalarFuncSig_GroupingSig)
return sig, nil
}
Expand All @@ -60,12 +62,12 @@ type BuiltinGroupingImplSig struct {

// TODO these are two temporary fields for tests
mode tipb.GroupingMode
groupingMarks map[uint64]struct{}
groupingMarks []map[uint64]struct{}
isMetaInited bool
}

// SetMetadata will fill grouping function with comparison groupingMarks when rewriting grouping function.
func (b *BuiltinGroupingImplSig) SetMetadata(mode tipb.GroupingMode, groupingMarks map[uint64]struct{}) error {
func (b *BuiltinGroupingImplSig) SetMetadata(mode tipb.GroupingMode, groupingMarks []map[uint64]struct{}) error {
b.setGroupingMode(mode)
b.setMetaGroupingMarks(groupingMarks)
b.isMetaInited = true
Expand All @@ -81,7 +83,7 @@ func (b *BuiltinGroupingImplSig) setGroupingMode(mode tipb.GroupingMode) {
b.mode = mode
}

func (b *BuiltinGroupingImplSig) setMetaGroupingMarks(groupingMarks map[uint64]struct{}) {
func (b *BuiltinGroupingImplSig) setMetaGroupingMarks(groupingMarks []map[uint64]struct{}) {
b.groupingMarks = groupingMarks
}

Expand All @@ -96,9 +98,15 @@ func (b *BuiltinGroupingImplSig) metadata() proto.Message {
return &tipb.GroupingFunctionMetadata{}
}
args := &tipb.GroupingFunctionMetadata{}
*(args.Mode) = b.mode
for groupingMark := range b.groupingMarks {
args.GroupingMarks = append(args.GroupingMarks, groupingMark)
args.Mode = &b.mode
for _, groupingMark := range b.groupingMarks {
gm := &tipb.GroupingMark{
GroupingNums: make([]uint64, 0, len(groupingMark)),
}
for k := range groupingMark {
gm.GroupingNums = append(gm.GroupingNums, k)
}
args.GroupingMarks = append(args.GroupingMarks, gm)
}
return args
}
Expand All @@ -112,62 +120,97 @@ func (b *BuiltinGroupingImplSig) Clone() builtinFunc {
return newSig
}

func (b *BuiltinGroupingImplSig) getMetaGroupingMarks() map[uint64]struct{} {
func (b *BuiltinGroupingImplSig) getMetaGroupingMarks() []map[uint64]struct{} {
return b.groupingMarks
}

func (b *BuiltinGroupingImplSig) getMetaGroupingID() uint64 {
var metaGroupingID uint64
groupingIDs := b.getMetaGroupingMarks()
for key := range groupingIDs {
metaGroupingID = key
}
return metaGroupingID
}

func (b *BuiltinGroupingImplSig) checkMetadata() error {
if !b.isMetaInited {
return errors.Errorf("Meta data hasn't been initialized")
}
mode := b.getGroupingMode()
groupingIDs := b.getMetaGroupingMarks()
groupingMarks := b.getMetaGroupingMarks()
if mode != tipb.GroupingMode_ModeBitAnd && mode != tipb.GroupingMode_ModeNumericCmp && mode != tipb.GroupingMode_ModeNumericSet {
return errors.Errorf("Mode of meta data in grouping function is invalid. input mode: %d", mode)
} else if (mode == tipb.GroupingMode_ModeBitAnd || mode == tipb.GroupingMode_ModeNumericCmp) && len(groupingIDs) != 1 {
return errors.Errorf("Invalid number of groupingID. mode: %d, number of groupingID: %d", mode, len(b.groupingMarks))
} else if mode == tipb.GroupingMode_ModeBitAnd || mode == tipb.GroupingMode_ModeNumericCmp {
for _, groupingMark := range groupingMarks {
if len(groupingMark) != 1 {
return errors.Errorf("Invalid number of groupingID. mode: %d, number of groupingID: %d", mode, len(b.groupingMarks))
}
}
}
return nil
}

func (b *BuiltinGroupingImplSig) groupingImplBitAnd(groupingID uint64, metaGroupingID uint64) int64 {
if groupingID&metaGroupingID > 0 {
return 1
func (b *BuiltinGroupingImplSig) groupingImplBitAnd(groupingID uint64) int64 {
groupingMarks := b.getMetaGroupingMarks()
res := uint64(0)
for _, groupingMark := range groupingMarks {
// for Bit-And mode, there is only one element in groupingMark.
for k := range groupingMark {
res <<= 1
if groupingID&k <= 0 {
// col is not needed, being filled with null and grouped. = 1
res += 1
}
// col is needed in this grouping set, meaning not being grouped. = 0
}
}
return 0
return int64(res)
}

func (b *BuiltinGroupingImplSig) groupingImplNumericCmp(groupingID uint64, metaGroupingID uint64) int64 {
if groupingID > metaGroupingID {
return 1
func (b *BuiltinGroupingImplSig) groupingImplNumericCmp(groupingID uint64) int64 {
groupingMarks := b.getMetaGroupingMarks()
res := uint64(0)
for _, groupingMark := range groupingMarks {
// for Num-Cmp mode, there is only one element in groupingMark.
for k := range groupingMark {
res <<= 1
if groupingID <= k {
// col is not needed, being filled with null and grouped. = 1
res += 1
}
// col is needed, meaning not being grouped. = 0
}
}
return 0
return int64(res)
}

func (b *BuiltinGroupingImplSig) groupingImplNumericSet(groupingID uint64) int64 {
groupingIDs := b.getMetaGroupingMarks()
_, ok := groupingIDs[groupingID]
if ok {
return 0
groupingMarks := b.getMetaGroupingMarks()
res := uint64(0)
for _, groupingMark := range groupingMarks {
res <<= 1
// for Num-Set mode, traverse the slice to find the match.
_, ok := groupingMark[groupingID]
if !ok {
// in Num-Set mode, this map maintains the needed-col's grouping set (GIDs)
// when ok is NOT true, col is not needed, being filled with null and grouped. = 1
res += 1
}
// it means col is needed, meaning not being filled with null and grouped. = 0
}
return 1
return int64(res)
}

// since grouping function may have multi args like grouping(a,b), so the source columns may greater than 1.
// reference: https://dev.mysql.com/blog-archive/mysql-8-0-grouping-function/
// Let's say GROUPING(b,a) group by a,b with rollup. (Note the b,a sequence is reversed from gby item)
// if GROUPING (b,a) returns 3 (11 in bits), it means that NULL in column “b” and NULL in column “a” for that
// row is produced by a ROLLUP operation. If result is 2 (10 in bits), meaning NULL in column “a” alone is the
// result of ROLLUP operation.
//
// Formula: GROUPING(x,y,z) = GROUPING(x) << 2 + GROUPING(y) << 1 + GROUPING(z)
//
// so for the multi args GROUPING FUNCTION, we should return all the simple col grouping marks. When evaluating,
// after all grouping marks & with gid in sequence, the final res is derived as the formula said. This also means
// that the grouping function accepts a maximum of 64 parameters, obviously the result is an uint64.
func (b *BuiltinGroupingImplSig) grouping(groupingID uint64) int64 {
switch b.mode {
case tipb.GroupingMode_ModeBitAnd:
return b.groupingImplBitAnd(groupingID, b.getMetaGroupingID())
return b.groupingImplBitAnd(groupingID)
case tipb.GroupingMode_ModeNumericCmp:
return b.groupingImplNumericCmp(groupingID, b.getMetaGroupingID())
return b.groupingImplNumericCmp(groupingID)
case tipb.GroupingMode_ModeNumericSet:
return b.groupingImplNumericSet(groupingID)
}
Expand All @@ -193,14 +236,12 @@ func (b *BuiltinGroupingImplSig) groupingVec(groupingIds *chunk.Column, rowNum i
resContainer := result.Int64s()
switch b.mode {
case tipb.GroupingMode_ModeBitAnd:
metaGroupingID := b.getMetaGroupingID()
for i := 0; i < rowNum; i++ {
resContainer[i] = b.groupingImplBitAnd(groupingIds.GetUint64(i), metaGroupingID)
resContainer[i] = b.groupingImplBitAnd(groupingIds.GetUint64(i))
}
case tipb.GroupingMode_ModeNumericCmp:
metaGroupingID := b.getMetaGroupingID()
for i := 0; i < rowNum; i++ {
resContainer[i] = b.groupingImplNumericCmp(groupingIds.GetUint64(i), metaGroupingID)
resContainer[i] = b.groupingImplNumericCmp(groupingIds.GetUint64(i))
}
case tipb.GroupingMode_ModeNumericSet:
for i := 0; i < rowNum; i++ {
Expand Down
34 changes: 17 additions & 17 deletions expression/builtin_grouping_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func createGroupingFunc(ctx sessionctx.Context, args []Expression) (*BuiltinGrou
return nil, err
}
bf.tp.SetFlen(1)
sig := &BuiltinGroupingImplSig{bf, 0, map[uint64]struct{}{}, false}
sig := &BuiltinGroupingImplSig{bf, 0, []map[uint64]struct{}{}, false}
sig.setPbCode(tipb.ScalarFuncSig_GroupingSig)
return sig, nil
}
Expand All @@ -63,23 +63,23 @@ func TestGrouping(t *testing.T) {
expectResult uint64
}{
// GroupingMode_ModeBitAnd
{1, 1, map[uint64]struct{}{1: {}}, 1},
{1, 1, map[uint64]struct{}{3: {}}, 1},
{1, 1, map[uint64]struct{}{6: {}}, 0},
{2, 1, map[uint64]struct{}{1: {}}, 0},
{2, 1, map[uint64]struct{}{3: {}}, 1},
{2, 1, map[uint64]struct{}{6: {}}, 1},
{4, 1, map[uint64]struct{}{2: {}}, 0},
{4, 1, map[uint64]struct{}{4: {}}, 1},
{4, 1, map[uint64]struct{}{6: {}}, 1},
{1, 1, map[uint64]struct{}{1: {}}, 0},
{1, 1, map[uint64]struct{}{3: {}}, 0},
{1, 1, map[uint64]struct{}{6: {}}, 1},
{2, 1, map[uint64]struct{}{1: {}}, 1},
{2, 1, map[uint64]struct{}{3: {}}, 0},
{2, 1, map[uint64]struct{}{6: {}}, 0},
{4, 1, map[uint64]struct{}{2: {}}, 1},
{4, 1, map[uint64]struct{}{4: {}}, 0},
{4, 1, map[uint64]struct{}{6: {}}, 0},

// GroupingMode_ModeNumericCmp
{0, 2, map[uint64]struct{}{0: {}}, 0},
{0, 2, map[uint64]struct{}{2: {}}, 0},
{2, 2, map[uint64]struct{}{0: {}}, 1},
{2, 2, map[uint64]struct{}{1: {}}, 1},
{2, 2, map[uint64]struct{}{2: {}}, 0},
{2, 2, map[uint64]struct{}{3: {}}, 0},
{0, 2, map[uint64]struct{}{0: {}}, 1},
{0, 2, map[uint64]struct{}{2: {}}, 1},
{2, 2, map[uint64]struct{}{0: {}}, 0},
{2, 2, map[uint64]struct{}{1: {}}, 0},
{2, 2, map[uint64]struct{}{2: {}}, 1},
{2, 2, map[uint64]struct{}{3: {}}, 1},

// GroupingMode_ModeNumericSet
{1, 3, map[uint64]struct{}{1: {}, 2: {}}, 0},
Expand All @@ -95,7 +95,7 @@ func TestGrouping(t *testing.T) {
groupingFunc, err := createGroupingFunc(ctx, args)
require.NoError(t, err, comment)

err = groupingFunc.SetMetadata(testCase.mode, testCase.groupingIDs)
err = groupingFunc.SetMetadata(testCase.mode, []map[uint64]struct{}{testCase.groupingIDs})
require.NoError(t, err, comment)

actualResult, err := evalBuiltinFunc(groupingFunc, chunk.Row{})
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ require (
github.com/pingcap/log v1.1.1-0.20230317032135-a0d097d16e22
github.com/pingcap/sysutil v1.0.1-0.20230407040306-fb007c5aff21
github.com/pingcap/tidb/parser v0.0.0-20211011031125-9b13dc409c5e
github.com/pingcap/tipb v0.0.0-20230523034258-1bbc3bbbd369
github.com/pingcap/tipb v0.0.0-20230602100112-acb7942db1ca
github.com/pkg/errors v0.9.1
github.com/prometheus/client_golang v1.15.1
github.com/prometheus/client_model v0.4.0
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -798,8 +798,8 @@ github.com/pingcap/log v1.1.1-0.20230317032135-a0d097d16e22 h1:2SOzvGvE8beiC1Y4g
github.com/pingcap/log v1.1.1-0.20230317032135-a0d097d16e22/go.mod h1:DWQW5jICDR7UJh4HtxXSM20Churx4CQL0fwL/SoOSA4=
github.com/pingcap/sysutil v1.0.1-0.20230407040306-fb007c5aff21 h1:QV6jqlfOkh8hqvEAgwBZa+4bSgO0EeKC7s5c6Luam2I=
github.com/pingcap/sysutil v1.0.1-0.20230407040306-fb007c5aff21/go.mod h1:QYnjfA95ZaMefyl1NO8oPtKeb8pYUdnDVhQgf+qdpjM=
github.com/pingcap/tipb v0.0.0-20230523034258-1bbc3bbbd369 h1:f0c37nxxOl7C40+mC5bO9+IbVf8ia1frMU/WD0Heo4E=
github.com/pingcap/tipb v0.0.0-20230523034258-1bbc3bbbd369/go.mod h1:A7mrd7WHBl1o63LE2bIBGEJMTNWXqhgmYiOvMLxozfs=
github.com/pingcap/tipb v0.0.0-20230602100112-acb7942db1ca h1:J2HQyR5v1AcoBzx5/AYJW9XFSIl6si6YoC6yGI1W89c=
github.com/pingcap/tipb v0.0.0-20230602100112-acb7942db1ca/go.mod h1:A7mrd7WHBl1o63LE2bIBGEJMTNWXqhgmYiOvMLxozfs=
github.com/pkg/browser v0.0.0-20180916011732-0a3d74bf9ce4 h1:49lOXmGaUpV9Fz3gd7TFZY106KVlPVa5jcYD1gaQf98=
github.com/pkg/browser v0.0.0-20180916011732-0a3d74bf9ce4/go.mod h1:4OwLy04Bl9Ef3GJJCoec+30X3LQs/0/m4HFRt/2LUSA=
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
Expand Down
1 change: 1 addition & 0 deletions parser/ast/functions.go
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,7 @@ const (
TiDBRowChecksum = "tidb_row_checksum"
GetLock = "get_lock"
ReleaseLock = "release_lock"
Grouping = "grouping"

// encryption and compression functions
AesDecrypt = "aes_decrypt"
Expand Down

0 comments on commit aedbcd0

Please sign in to comment.