Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

expression: rewrite builtin func values #4491

Merged
merged 5 commits into from
Sep 12, 2017
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 162 additions & 13 deletions expression/builtin_other.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"github.com/pingcap/tidb/context"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/util/types"
"github.com/pingcap/tidb/util/types/json"
)

var (
Expand All @@ -39,7 +40,13 @@ var (
_ builtinFunc = &builtinGetVarSig{}
_ builtinFunc = &builtinLockSig{}
_ builtinFunc = &builtinReleaseLockSig{}
_ builtinFunc = &builtinValuesSig{}
_ builtinFunc = &builtinValuesIntSig{}
_ builtinFunc = &builtinValuesRealSig{}
_ builtinFunc = &builtinValuesDecimalSig{}
_ builtinFunc = &builtinValuesStringSig{}
_ builtinFunc = &builtinValuesTimeSig{}
_ builtinFunc = &builtinValuesDurationSig{}
_ builtinFunc = &builtinValuesJSONSig{}
_ builtinFunc = &builtinBitCountSig{}
)

Expand Down Expand Up @@ -149,31 +156,173 @@ type valuesFunctionClass struct {
baseFunctionClass

offset int
tp *types.FieldType
}

func (c *valuesFunctionClass) getFunction(ctx context.Context, args []Expression) (builtinFunc, error) {
err := errors.Trace(c.verifyArgs(args))
bt := &builtinValuesSig{newBaseBuiltinFunc(args, ctx), c.offset}
bt.foldable = false
return bt.setSelf(bt), errors.Trace(err)
func (c *valuesFunctionClass) getFunction(ctx context.Context, args []Expression) (sig builtinFunc, err error) {
if err = errors.Trace(c.verifyArgs(args)); err != nil {
return nil, err
}
bf := newBaseBuiltinFunc(args, ctx)
bf.tp = c.tp
bf.foldable = false
switch fieldTp2EvalTp(c.tp) {
case tpInt:
sig = &builtinValuesIntSig{baseIntBuiltinFunc{bf}, c.offset}
case tpReal:
sig = &builtinValuesRealSig{baseRealBuiltinFunc{bf}, c.offset}
case tpDecimal:
sig = &builtinValuesRealSig{baseRealBuiltinFunc{bf}, c.offset}
case tpString:
sig = &builtinValuesStringSig{baseStringBuiltinFunc{bf}, c.offset}
case tpDatetime, tpTimestamp:
sig = &builtinValuesTimeSig{baseTimeBuiltinFunc{bf}, c.offset}
case tpDuration:
sig = &builtinValuesDurationSig{baseDurationBuiltinFunc{bf}, c.offset}
case tpJSON:
sig = &builtinValuesJSONSig{baseJSONBuiltinFunc{bf}, c.offset}
}
return sig.setSelf(sig), nil
}

type builtinValuesSig struct {
baseBuiltinFunc
type builtinValuesIntSig struct {
baseIntBuiltinFunc

offset int
}

// evalInt evals a builtinValuesIntSig.
// See https://dev.mysql.com/doc/refman/5.7/en/miscellaneous-functions.html#function_values
func (b *builtinValuesIntSig) evalInt(_ []types.Datum) (int64, bool, error) {
values := b.ctx.GetSessionVars().CurrInsertValues
if values == nil {
return 0, true, errors.New("Session current insert values is nil")
}
row := values.([]types.Datum)
if b.offset < len(row) {
return row[b.offset].GetInt64(), false, nil
}
return 0, true, errors.Errorf("Session current insert values len %d and column's offset %v don't match", len(row), b.offset)
}

type builtinValuesRealSig struct {
baseRealBuiltinFunc

offset int
}

// evalReal evals a builtinValuesRealSig.
// See https://dev.mysql.com/doc/refman/5.7/en/miscellaneous-functions.html#function_values
func (b *builtinValuesRealSig) evalReal(_ []types.Datum) (float64, bool, error) {
values := b.ctx.GetSessionVars().CurrInsertValues
if values == nil {
return 0, true, errors.New("Session current insert values is nil")
}
row := values.([]types.Datum)
if b.offset < len(row) {
return row[b.offset].GetFloat64(), false, nil
}
return 0, true, errors.Errorf("Session current insert values len %d and column's offset %v don't match", len(row), b.offset)
}

type builtinValuesDecimalSig struct {
baseDecimalBuiltinFunc

offset int
}

// evalDecimal evals a builtinValuesDecimalSig.
// See https://dev.mysql.com/doc/refman/5.7/en/miscellaneous-functions.html#function_values
func (b *builtinValuesDecimalSig) evalDecimal(_ []types.Datum) (*types.MyDecimal, bool, error) {
values := b.ctx.GetSessionVars().CurrInsertValues
if values == nil {
return nil, true, errors.New("Session current insert values is nil")
}
row := values.([]types.Datum)
if b.offset < len(row) {
return row[b.offset].GetMysqlDecimal(), false, nil
}
return nil, true, errors.Errorf("Session current insert values len %d and column's offset %v don't match", len(row), b.offset)
}

type builtinValuesStringSig struct {
baseStringBuiltinFunc

offset int
}

// evalString evals a builtinValuesStringSig.
// See https://dev.mysql.com/doc/refman/5.7/en/miscellaneous-functions.html#function_values
func (b *builtinValuesStringSig) evalString(_ []types.Datum) (string, bool, error) {
values := b.ctx.GetSessionVars().CurrInsertValues
if values == nil {
return "", true, errors.New("Session current insert values is nil")
}
row := values.([]types.Datum)
if b.offset < len(row) {
return row[b.offset].GetString(), false, nil
}
return "", true, errors.Errorf("Session current insert values len %d and column's offset %v don't match", len(row), b.offset)
}

type builtinValuesTimeSig struct {
baseTimeBuiltinFunc

offset int
}

// // evalTime evals a builtinValuesTimeSig.
// See https://dev.mysql.com/doc/refman/5.7/en/miscellaneous-functions.html#function_values
func (b *builtinValuesTimeSig) evalTime(_ []types.Datum) (types.Time, bool, error) {
values := b.ctx.GetSessionVars().CurrInsertValues
if values == nil {
return types.Time{}, true, errors.New("Session current insert values is nil")
}
row := values.([]types.Datum)
if b.offset < len(row) {
return row[b.offset].GetMysqlTime(), false, nil
}
return types.Time{}, true, errors.Errorf("Session current insert values len %d and column's offset %v don't match", len(row), b.offset)
}

type builtinValuesDurationSig struct {
baseDurationBuiltinFunc

offset int
}

// // evalDuration evals a builtinValuesDurationSig.
// See https://dev.mysql.com/doc/refman/5.7/en/miscellaneous-functions.html#function_values
func (b *builtinValuesDurationSig) evalDuration(_ []types.Datum) (types.Duration, bool, error) {
values := b.ctx.GetSessionVars().CurrInsertValues
if values == nil {
return types.Duration{}, true, errors.New("Session current insert values is nil")
}
row := values.([]types.Datum)
if b.offset < len(row) {
return row[b.offset].GetMysqlDuration(), false, nil
}
return types.Duration{}, true, errors.Errorf("Session current insert values len %d and column's offset %v don't match", len(row), b.offset)
}

type builtinValuesJSONSig struct {
baseJSONBuiltinFunc

offset int
}

func (b *builtinValuesSig) eval(_ []types.Datum) (types.Datum, error) {
// evalJSON evals a builtinValuesJSONSig.
// See https://dev.mysql.com/doc/refman/5.7/en/miscellaneous-functions.html#function_values
func (b *builtinValuesJSONSig) evalJSON(_ []types.Datum) (json.JSON, bool, error) {
values := b.ctx.GetSessionVars().CurrInsertValues
if values == nil {
return types.Datum{}, errors.New("Session current insert values is nil")
return json.JSON{}, true, errors.New("Session current insert values is nil")
}
row := values.([]types.Datum)
if len(row) > b.offset {
return row[b.offset], nil
if b.offset < len(row) {
return row[b.offset].GetMysqlJSON(), false, nil
}
return types.Datum{}, errors.Errorf("Session current insert values len %d and column's offset %v don't match", len(row), b.offset)
return json.JSON{}, true, errors.Errorf("Session current insert values len %d and column's offset %v don't match", len(row), b.offset)
}

type bitCountFunctionClass struct {
Expand Down
3 changes: 2 additions & 1 deletion expression/builtin_other_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (

. "github.com/pingcap/check"
"github.com/pingcap/tidb/ast"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/util/testleak"
"github.com/pingcap/tidb/util/types"
Expand Down Expand Up @@ -164,7 +165,7 @@ func (s *testEvaluatorSuite) TestGetVar(c *C) {

func (s *testEvaluatorSuite) TestValues(c *C) {
defer testleak.AfterTest(c)()
fc := &valuesFunctionClass{baseFunctionClass{ast.Values, 0, 0}, 1}
fc := &valuesFunctionClass{baseFunctionClass{ast.Values, 0, 0}, 1, types.NewFieldType(mysql.TypeVarchar)}
_, err := fc.getFunction(s.ctx, datumsToConstants(types.MakeDatums("")))
c.Assert(err, ErrorMatches, "*Incorrect parameter count in the call to native function 'values'")
sig, err := fc.getFunction(s.ctx, datumsToConstants(types.MakeDatums()))
Expand Down
2 changes: 1 addition & 1 deletion expression/expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -589,7 +589,7 @@ func NewCastFunc(tp *types.FieldType, arg Expression, ctx context.Context) Expre

// NewValuesFunc creates a new values function.
func NewValuesFunc(offset int, retTp *types.FieldType, ctx context.Context) *ScalarFunction {
fc := &valuesFunctionClass{baseFunctionClass{ast.Values, 0, 0}, offset}
fc := &valuesFunctionClass{baseFunctionClass{ast.Values, 0, 0}, offset, retTp}
bt, _ := fc.getFunction(ctx, nil)
return &ScalarFunction{
FuncName: model.NewCIStr(ast.Values),
Expand Down
2 changes: 1 addition & 1 deletion expression/expression_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func (s *testEvaluatorSuite) TestNewValuesFunc(c *C) {
res := NewValuesFunc(0, types.NewFieldType(mysql.TypeLonglong), s.ctx)
c.Assert(res.FuncName.O, Equals, "values")
c.Assert(res.RetType.Tp, Equals, mysql.TypeLonglong)
_, ok := res.Function.(*builtinValuesSig)
_, ok := res.Function.(*builtinValuesIntSig)
c.Assert(ok, IsTrue)
}

Expand Down
10 changes: 10 additions & 0 deletions expression/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2392,6 +2392,16 @@ func (s *testIntegrationSuite) TestOtherBuiltin(c *C) {
result = tk.MustQuery(`select bit_count(121), bit_count(-1), bit_count(null), bit_count("1231aaa");`)
result.Check(testkit.Rows("5 64 <nil> 7"))

tk.MustExec("drop table if exists t")
tk.MustExec("create table t(a int primary key, b time, c double, d varchar(10))")
tk.MustExec(`insert into t values(1, '01:01:01', 1.1, "1"), (2, '02:02:02', 2.2, "2")`)
tk.MustExec(`insert into t(a, b) values(1, '12:12:12') on duplicate key update a = values(b)`)
result = tk.MustQuery(`select a from t order by a`)
result.Check(testkit.Rows("2", "121212"))
tk.MustExec(`insert into t values(2, '12:12:12', 1.1, "3.3") on duplicate key update a = values(c) + values(d)`)
result = tk.MustQuery(`select a from t order by a`)
result.Check(testkit.Rows("4", "121212"))

// for setvar, getvar
tk.MustExec(`set @varname = "Abc"`)
result = tk.MustQuery(`select @varname, @VARNAME`)
Expand Down
20 changes: 18 additions & 2 deletions expression/scalar_function.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,24 @@ func (sf *ScalarFunction) Clone() Expression {
case ast.Cast:
return buildCastFunction(sf.GetArgs()[0], sf.GetType(), sf.GetCtx())
case ast.Values:
v := sf.Function.(*builtinValuesSig)
return NewValuesFunc(v.offset, sf.GetType(), sf.GetCtx())
var offset int
switch fieldTp2EvalTp(sf.GetType()) {
case tpInt:
offset = sf.Function.(*builtinValuesIntSig).offset
case tpReal:
offset = sf.Function.(*builtinValuesRealSig).offset
case tpDecimal:
offset = sf.Function.(*builtinValuesDecimalSig).offset
case tpString:
offset = sf.Function.(*builtinValuesStringSig).offset
case tpDatetime, tpTimestamp:
offset = sf.Function.(*builtinValuesTimeSig).offset
case tpDuration:
offset = sf.Function.(*builtinValuesDurationSig).offset
case tpJSON:
offset = sf.Function.(*builtinValuesJSONSig).offset
}
return NewValuesFunc(offset, sf.GetType(), sf.GetCtx())
}
newFunc, _ := NewFunction(sf.GetCtx(), sf.FuncName.L, sf.RetType, newArgs...)
return newFunc
Expand Down
2 changes: 1 addition & 1 deletion expression/scalar_function_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func (s *testEvaluatorSuite) TestScalarFunction(c *C) {
c.Assert(ok, IsTrue)
c.Assert(newSf.FuncName.O, Equals, "values")
c.Assert(newSf.RetType.Tp, Equals, mysql.TypeLonglong)
_, ok = newSf.Function.(*builtinValuesSig)
_, ok = newSf.Function.(*builtinValuesIntSig)
c.Assert(ok, IsTrue)
}

Expand Down