/
base_func.go
112 lines (103 loc) · 3.89 KB
/
base_func.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
package aggregation
import (
"strings"
"github.com/cznic/mathutil"
"github.com/pingcap/errors"
"github.com/pingcap/parser/ast"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/tidb/types"
"github.com/zhihu/zetta/tablestore/mysql/expression"
"github.com/zhihu/zetta/tablestore/mysql/sctx"
)
// baseFuncDesc describes an aggregate function call: its name, its argument
// expressions and its inferred return type. It is only used in the planner.
type baseFuncDesc struct {
	// Name represents the function name. newBaseFuncDesc stores it
	// lower-cased before typeInfer matches it against the ast.AggFunc* names.
	Name string
	// Args represents the arguments of the function.
	Args []expression.Expression
	// RetTp represents the return type of the function; it is set by
	// typeInfer.
	RetTp *types.FieldType
}
// newBaseFuncDesc builds a baseFuncDesc for the aggregate function called
// name (lower-cased, so the match is case-insensitive) over args and runs
// return-type inference on it. When inference fails, the descriptor built so
// far is returned together with the error.
func newBaseFuncDesc(ctx sctx.Context, name string, args []expression.Expression) (baseFuncDesc, error) {
	desc := baseFuncDesc{
		Name: strings.ToLower(name),
		Args: args,
	}
	if err := desc.typeInfer(ctx); err != nil {
		return desc, err
	}
	return desc, nil
}
// typeInfer infers the return type (RetTp) of the aggregate function named by
// a.Name from its arguments.
//
// ctx is currently unused and is kept for signature stability. An error is
// returned when a.Name is not a supported aggregate, or when an aggregate
// whose inference reads its first argument is given no arguments.
func (a *baseFuncDesc) typeInfer(ctx sctx.Context) error {
	var infer func()
	switch a.Name {
	case ast.AggFuncCount:
		// COUNT's type does not depend on its arguments.
		a.typeInfer4Count()
		return nil
	case ast.AggFuncSum:
		infer = a.typeInfer4Sum
	case ast.AggFuncAvg:
		infer = a.typeInfer4Avg
	case ast.AggFuncMax, ast.AggFuncMin, ast.AggFuncFirstRow:
		infer = a.typeInfer4MaxMin
	default:
		return errors.Errorf("unsupported agg function: %s", a.Name)
	}
	// The remaining inferers all read a.Args[0]; report a missing argument as
	// an error instead of panicking with an index out of range.
	if len(a.Args) == 0 {
		return errors.Errorf("agg function %s requires at least one argument", a.Name)
	}
	infer()
	return nil
}
// typeInfer4Count sets RetTp for COUNT: a mysql.TypeLonglong with Flen 21 and
// no fractional digits, marked NOT NULL, with types.SetBinChsClnFlag applied.
func (a *baseFuncDesc) typeInfer4Count() {
	ft := types.NewFieldType(mysql.TypeLonglong)
	ft.Flen, ft.Decimal = 21, 0
	// COUNT never yields NULL.
	ft.Flag |= mysql.NotNullFlag
	types.SetBinChsClnFlag(ft)
	a.RetTp = ft
}
// typeInfer4Sum sets RetTp for SUM based on the type of the first argument:
//   - integer arguments yield a DECIMAL of width MaxDecimalWidth and scale 0;
//   - DECIMAL arguments yield a DECIMAL of width MaxDecimalWidth keeping the
//     argument's scale, except that a negative scale or one above
//     MaxDecimalScale becomes MaxDecimalScale;
//   - FLOAT/DOUBLE arguments yield a DOUBLE of width MaxRealWidth keeping the
//     argument's decimals;
//   - anything else yields a DOUBLE of width MaxRealWidth with unspecified
//     decimals.
//
// types.SetBinChsClnFlag is applied to the result in every case.
func (a *baseFuncDesc) typeInfer4Sum() {
	argTp := a.Args[0].GetType()
	var ft *types.FieldType
	switch argTp.Tp {
	case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong:
		ft = types.NewFieldType(mysql.TypeNewDecimal)
		ft.Flen, ft.Decimal = mysql.MaxDecimalWidth, 0
	case mysql.TypeNewDecimal:
		scale := argTp.Decimal
		if scale < 0 || scale > mysql.MaxDecimalScale {
			scale = mysql.MaxDecimalScale
		}
		ft = types.NewFieldType(mysql.TypeNewDecimal)
		ft.Flen, ft.Decimal = mysql.MaxDecimalWidth, scale
	case mysql.TypeDouble, mysql.TypeFloat:
		ft = types.NewFieldType(mysql.TypeDouble)
		ft.Flen, ft.Decimal = mysql.MaxRealWidth, argTp.Decimal
	default:
		ft = types.NewFieldType(mysql.TypeDouble)
		ft.Flen, ft.Decimal = mysql.MaxRealWidth, types.UnspecifiedLength
	}
	types.SetBinChsClnFlag(ft)
	a.RetTp = ft
}
// typeInfer4Avg sets RetTp for AVG based on the type of the first argument:
//   - integer and DECIMAL arguments yield a DECIMAL of width MaxDecimalWidth
//     whose scale is the argument's scale plus types.DivFracIncr, capped at
//     MaxDecimalScale (a negative argument scale maps directly to
//     MaxDecimalScale);
//   - FLOAT/DOUBLE arguments yield a DOUBLE of width MaxRealWidth keeping the
//     argument's decimals;
//   - anything else yields a DOUBLE of width MaxRealWidth with unspecified
//     decimals.
//
// types.SetBinChsClnFlag is applied to the result in every case.
func (a *baseFuncDesc) typeInfer4Avg() {
	argTp := a.Args[0].GetType()
	var ft *types.FieldType
	switch argTp.Tp {
	case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong, mysql.TypeNewDecimal:
		scale := mysql.MaxDecimalScale
		if argTp.Decimal >= 0 {
			scale = mathutil.Min(argTp.Decimal+types.DivFracIncr, mysql.MaxDecimalScale)
		}
		ft = types.NewFieldType(mysql.TypeNewDecimal)
		ft.Flen, ft.Decimal = mysql.MaxDecimalWidth, scale
	case mysql.TypeDouble, mysql.TypeFloat:
		ft = types.NewFieldType(mysql.TypeDouble)
		ft.Flen, ft.Decimal = mysql.MaxRealWidth, argTp.Decimal
	default:
		ft = types.NewFieldType(mysql.TypeDouble)
		ft.Flen, ft.Decimal = mysql.MaxRealWidth, types.UnspecifiedLength
	}
	types.SetBinChsClnFlag(ft)
	a.RetTp = ft
}
// typeInfer4MaxMin sets RetTp for MAX, MIN and FIRSTROW.
//
// By default the result is the first argument's own field type (shared, not
// copied). For MAX and MIN on any type other than BIT, a clone with
// NotNullFlag cleared is used instead (presumably because MAX/MIN can produce
// NULL, e.g. over empty input; confirm). For MAX and MIN, ENUM and SET results
// are replaced by a plain mysql.TypeString of length MaxFieldCharLength;
// FIRSTROW keeps ENUM/SET as-is.
func (a *baseFuncDesc) typeInfer4MaxMin() {
	argTp := a.Args[0].GetType()
	isMaxOrMin := a.Name == ast.AggFuncMax || a.Name == ast.AggFuncMin

	retTp := argTp
	if isMaxOrMin && argTp.Tp != mysql.TypeBit {
		retTp = argTp.Clone()
		retTp.Flag &^= mysql.NotNullFlag
	}

	// TODO: fix other aggFuncs for TypeEnum & TypeSet
	isEnumOrSet := retTp.Tp == mysql.TypeEnum || retTp.Tp == mysql.TypeSet
	if isEnumOrSet && a.Name != ast.AggFuncFirstRow {
		retTp = &types.FieldType{Tp: mysql.TypeString, Flen: mysql.MaxFieldCharLength}
	}

	a.RetTp = retTp
}