Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

*: rewrite builtin function ROW && remove Datum.KindRow #4480

Merged
merged 13 commits into from
Sep 11, 2017
28 changes: 13 additions & 15 deletions expression/builtin_other.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,29 +47,27 @@ type rowFunctionClass struct {
baseFunctionClass
}

func (c *rowFunctionClass) getFunction(ctx context.Context, args []Expression) (builtinFunc, error) {
if err := c.verifyArgs(args); err != nil {
func (c *rowFunctionClass) getFunction(ctx context.Context, args []Expression) (sig builtinFunc, err error) {
if err = c.verifyArgs(args); err != nil {
return nil, errors.Trace(err)
}
sig := &builtinRowSig{newBaseBuiltinFunc(args, ctx)}
argTps := make([]evalTp, len(args))
for i := range argTps {
argTps[i] = fieldTp2EvalTp(args[i].GetType())
}
bf := newBaseBuiltinFuncWithTp(args, ctx, tpString, argTps...)
bf.foldable = false
sig = &builtinRowSig{baseStringBuiltinFunc{bf}}
return sig.setSelf(sig), nil
}

type builtinRowSig struct {
baseBuiltinFunc
}

func (b *builtinRowSig) canBeFolded() bool {
return false
baseStringBuiltinFunc
}

func (b *builtinRowSig) eval(row []types.Datum) (d types.Datum, err error) {
args, err := b.evalArgs(row)
if err != nil {
return types.Datum{}, errors.Trace(err)
}
d.SetRow(args)
return
// rowFunc should always be flattened in expression rewrite phrase.
func (b *builtinRowSig) evalString(row []types.Datum) (string, bool, error) {
panic("builtinRowSig.evalString() should never be called.")
}

type setVarFunctionClass struct {
Expand Down
24 changes: 3 additions & 21 deletions expression/builtin_other_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,27 +74,9 @@ func (s *testEvaluatorSuite) TestBitCount(c *C) {
func (s *testEvaluatorSuite) TestRowFunc(c *C) {
defer testleak.AfterTest(c)()
fc := funcs[ast.RowFunc]
testCases := []struct {
args []interface{}
}{
{[]interface{}{nil, nil}},
{[]interface{}{1, 2}},
{[]interface{}{"1", 2}},
{[]interface{}{"1", 2, true}},
{[]interface{}{"1", nil, true}},
{[]interface{}{"1", nil, true, nil}},
{[]interface{}{"1", 1.2, true, 120}},
}
for _, tc := range testCases {
fn, err := fc.getFunction(s.ctx, datumsToConstants(types.MakeDatums(tc.args...)))
c.Assert(err, IsNil)
d, err := fn.eval(types.MakeDatums(tc.args...))
c.Assert(err, IsNil)
c.Assert(d.Kind(), Equals, types.KindRow)
cmp, err := types.EqualDatums(nil, d.GetRow(), types.MakeDatums(tc.args...))
c.Assert(err, IsNil)
c.Assert(cmp, Equals, true)
}
fn, err := fc.getFunction(s.ctx, datumsToConstants(types.MakeDatums([]interface{}{"1", 1.2, true, 120}...)))
c.Assert(err, IsNil)
c.Assert(fn.canBeFolded(), IsFalse)
}

func (s *testEvaluatorSuite) TestSetVar(c *C) {
Expand Down
2 changes: 0 additions & 2 deletions expression/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,6 @@ func kindToFieldType(kind byte) types.FieldType {
ft.Tp = mysql.TypeEnum
case types.KindMysqlSet:
ft.Tp = mysql.TypeSet
case types.KindRow:
ft.Tp = mysql.TypeVarString
case types.KindInterface:
ft.Tp = mysql.TypeVarString
case types.KindMysqlDecimal:
Expand Down
15 changes: 2 additions & 13 deletions plan/expression_rewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,19 +117,14 @@ func getRowLen(e expression.Expression) int {
if f, ok := e.(*expression.ScalarFunction); ok && f.FuncName.L == ast.RowFunc {
return len(f.GetArgs())
}
if c, ok := e.(*expression.Constant); ok && c.Value.Kind() == types.KindRow {
return len(c.Value.GetRow())
}
return 1
}

func getRowArg(e expression.Expression, idx int) expression.Expression {
if f, ok := e.(*expression.ScalarFunction); ok {
return f.GetArgs()[idx]
}
c, _ := e.(*expression.Constant)
d := c.Value.GetRow()[idx]
return &expression.Constant{Value: d, RetType: c.GetType()}
return nil
}

// popRowArg pops the first element and return the rest of row.
Expand All @@ -143,12 +138,6 @@ func popRowArg(ctx context.Context, e expression.Expression) (ret expression.Exp
ret, err = expression.NewFunction(ctx, f.FuncName.L, f.GetType(), args[1:]...)
return ret, errors.Trace(err)
}
c, _ := e.(*expression.Constant)
if getRowLen(c) == 2 {
ret = &expression.Constant{Value: c.Value.GetRow()[1], RetType: c.GetType()}
} else {
ret = &expression.Constant{Value: types.NewDatum(c.Value.GetRow()[1:]), RetType: c.GetType()}
}
return
}

Expand Down Expand Up @@ -881,7 +870,7 @@ func (er *expressionRewriter) isTrueToScalarFunc(v *ast.IsTruthExpr) {

// inToExpression converts in expression to a scalar function. The argument lLen means the length of in list.
// The argument not means if the expression is not in. The tp stands for the expression type, which is always bool.
// a in (b, c, d) will be rewritted as `(a = b) or (a = c) or (a = d)`.
// a in (b, c, d) will be rewritten as `(a = b) or (a = c) or (a = d)`.
func (er *expressionRewriter) inToExpression(lLen int, not bool, tp *types.FieldType) {
stkLen := len(er.ctxStack)
l := getRowLen(er.ctxStack[stkLen-lLen-1])
Expand Down
4 changes: 0 additions & 4 deletions util/types/compare_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,6 @@ func (s *testCompareSuite) TestCompare(c *C) {
{[]byte("123"), 1234, -1},
{[]byte{}, nil, 1},

{[]interface{}{1, 2, 3}, []interface{}{1, 2, 3}, 0},
{[]interface{}{1, 3, 3}, []interface{}{1, 2, 3}, 1},
{[]interface{}{1, 2, 3}, []interface{}{2, 2, 3}, -1},

{NewBinaryLiteralFromUint(1, -1), 1, 0},
{NewBinaryLiteralFromUint(0x4D7953514C, -1), "MySQL", 0},
{NewBinaryLiteralFromUint(0, -1), uint64(10), -1},
Expand Down
48 changes: 5 additions & 43 deletions util/types/datum.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,11 @@ const (
KindMysqlBit byte = 11 // Used for BIT table column values.
KindMysqlSet byte = 12
KindMysqlTime byte = 13
KindRow byte = 14
KindInterface byte = 15
KindMinNotNull byte = 16
KindMaxValue byte = 17
KindRaw byte = 18
KindMysqlJSON byte = 19
KindInterface byte = 14
KindMinNotNull byte = 15
KindMaxValue byte = 16
KindRaw byte = 17
KindMysqlJSON byte = 18
)

// Datum is a data box holds different kind of data.
Expand Down Expand Up @@ -193,17 +192,6 @@ func (d *Datum) SetInterface(x interface{}) {
d.x = x
}

// GetRow gets row value.
func (d *Datum) GetRow() []Datum {
return d.x.([]Datum)
}

// SetRow sets row value.
func (d *Datum) SetRow(ds []Datum) {
d.k = KindRow
d.x = ds
}

// SetNull sets datum to nil.
func (d *Datum) SetNull() {
d.k = KindNull
Expand Down Expand Up @@ -391,11 +379,6 @@ func (d *Datum) SetValue(val interface{}) {
d.SetMysqlJSON(x)
case Time:
d.SetMysqlTime(x)
case []Datum:
d.SetRow(x)
case []interface{}:
ds := MakeDatums(x...)
d.SetRow(ds)
default:
d.SetInterface(x)
}
Expand Down Expand Up @@ -450,8 +433,6 @@ func (d *Datum) CompareDatum(sc *variable.StatementContext, ad Datum) (int, erro
return d.compareMysqlJSON(sc, ad.GetMysqlJSON())
case KindMysqlTime:
return d.compareMysqlTime(sc, ad.GetMysqlTime())
case KindRow:
return d.compareRow(sc, ad.GetRow())
default:
return 0, nil
}
Expand Down Expand Up @@ -643,25 +624,6 @@ func (d *Datum) compareMysqlTime(sc *variable.StatementContext, time Time) (int,
}
}

func (d *Datum) compareRow(sc *variable.StatementContext, row []Datum) (int, error) {
var dRow []Datum
if d.k == KindRow {
dRow = d.GetRow()
} else {
dRow = []Datum{*d}
}
for i := 0; i < len(row) && i < len(dRow); i++ {
cmp, err := dRow[i].CompareDatum(sc, row[i])
if err != nil {
return 0, err
}
if cmp != 0 {
return cmp, nil
}
}
return CompareInt64(int64(len(dRow)), int64(len(row))), nil
}

// ConvertTo converts a datum to the target field type.
func (d *Datum) ConvertTo(sc *variable.StatementContext, target *FieldType) (Datum, error) {
if d.k == KindNull {
Expand Down
2 changes: 1 addition & 1 deletion util/types/datum_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ func (ts *testDatumSuite) TestEqualDatums(c *C) {
func testEqualDatums(c *C, a []interface{}, b []interface{}, same bool) {
sc := new(variable.StatementContext)
sc.IgnoreTruncate = true
res, err := EqualDatums(sc, MakeDatums(a), MakeDatums(b))
res, err := EqualDatums(sc, MakeDatums(a...), MakeDatums(b...))
c.Assert(err, IsNil)
c.Assert(res, Equals, same, Commentf("a: %v, b: %v", a, b))
}
Expand Down