Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

planner, expression: support builtin function `NAME_CONST` #9261

Merged
merged 13 commits into from Feb 19, 2019
@@ -72,6 +72,14 @@ var (
_ builtinFunc = &builtinIsIPv4MappedSig{}
_ builtinFunc = &builtinIsIPv6Sig{}
_ builtinFunc = &builtinUUIDSig{}

_ builtinFunc = &builtinNameConstIntSig{}
_ builtinFunc = &builtinNameConstRealSig{}
_ builtinFunc = &builtinNameConstDecimalSig{}
_ builtinFunc = &builtinNameConstTimeSig{}
_ builtinFunc = &builtinNameConstDurationSig{}
_ builtinFunc = &builtinNameConstStringSig{}
_ builtinFunc = &builtinNameConstJSONSig{}
)

type sleepFunctionClass struct {
@@ -228,7 +236,7 @@ func (c *anyValueFunctionClass) getFunction(ctx sessionctx.Context, args []Expre
bf.tp.Charset, bf.tp.Collate, bf.tp.Flag = mysql.DefaultCharset, mysql.DefaultCollationName, 0
sig = &builtinTimeAnyValueSig{bf}
default:
panic("unexpected types.EvalType of builtin function ANY_VALUE")
return nil, errIncorrectArgs.GenWithStackByArgs("ANY_VALUE")
}
return sig, nil
}
@@ -808,7 +816,133 @@ type nameConstFunctionClass struct {
}

func (c *nameConstFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (builtinFunc, error) {
return nil, errFunctionNotExists.GenWithStackByArgs("FUNCTION", "NAME_CONST")
if err := c.verifyArgs(args); err != nil {
return nil, err
}
argTp := args[1].GetType().EvalType()
bf := newBaseBuiltinFuncWithTp(ctx, args, argTp, types.ETString, argTp)
*bf.tp = *args[1].GetType()
var sig builtinFunc
switch argTp {
case types.ETDecimal:
sig = &builtinNameConstDecimalSig{bf}
case types.ETDuration:
sig = &builtinNameConstDurationSig{bf}
case types.ETInt:
bf.tp.Decimal = 0
sig = &builtinNameConstIntSig{bf}
case types.ETJson:
sig = &builtinNameConstJSONSig{bf}
case types.ETReal:
sig = &builtinNameConstRealSig{bf}
case types.ETString:
bf.tp.Decimal = types.UnspecifiedLength
sig = &builtinNameConstStringSig{bf}
case types.ETDatetime, types.ETTimestamp:
bf.tp.Charset, bf.tp.Collate, bf.tp.Flag = mysql.DefaultCharset, mysql.DefaultCollationName, 0
sig = &builtinNameConstTimeSig{bf}
default:
return nil, errIncorrectArgs.GenWithStackByArgs("NAME_CONST")
}
return sig, nil
}

type builtinNameConstDecimalSig struct {
baseBuiltinFunc
}

func (b *builtinNameConstDecimalSig) Clone() builtinFunc {
newSig := &builtinNameConstDecimalSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
return newSig
}

func (b *builtinNameConstDecimalSig) evalDecimal(row chunk.Row) (*types.MyDecimal, bool, error) {
return b.args[1].EvalDecimal(b.ctx, row)
}

type builtinNameConstIntSig struct {
baseBuiltinFunc
}

func (b *builtinNameConstIntSig) Clone() builtinFunc {
newSig := &builtinNameConstIntSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
return newSig
}

func (b *builtinNameConstIntSig) evalInt(row chunk.Row) (int64, bool, error) {
return b.args[1].EvalInt(b.ctx, row)
}

type builtinNameConstRealSig struct {
baseBuiltinFunc
}

func (b *builtinNameConstRealSig) Clone() builtinFunc {
newSig := &builtinNameConstRealSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
return newSig
}

func (b *builtinNameConstRealSig) evalReal(row chunk.Row) (float64, bool, error) {
return b.args[1].EvalReal(b.ctx, row)
}

type builtinNameConstStringSig struct {
baseBuiltinFunc
}

func (b *builtinNameConstStringSig) Clone() builtinFunc {
newSig := &builtinNameConstStringSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
return newSig
}

func (b *builtinNameConstStringSig) evalString(row chunk.Row) (string, bool, error) {
return b.args[1].EvalString(b.ctx, row)
}

type builtinNameConstJSONSig struct {
baseBuiltinFunc
}

func (b *builtinNameConstJSONSig) Clone() builtinFunc {
newSig := &builtinNameConstJSONSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
return newSig
}

func (b *builtinNameConstJSONSig) evalJSON(row chunk.Row) (json.BinaryJSON, bool, error) {
return b.args[1].EvalJSON(b.ctx, row)
}

type builtinNameConstDurationSig struct {
baseBuiltinFunc
}

func (b *builtinNameConstDurationSig) Clone() builtinFunc {
newSig := &builtinNameConstDurationSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
return newSig
}

func (b *builtinNameConstDurationSig) evalDuration(row chunk.Row) (types.Duration, bool, error) {
return b.args[1].EvalDuration(b.ctx, row)
}

type builtinNameConstTimeSig struct {
baseBuiltinFunc
}

func (b *builtinNameConstTimeSig) Clone() builtinFunc {
newSig := &builtinNameConstTimeSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
return newSig
}

func (b *builtinNameConstTimeSig) evalTime(row chunk.Row) (types.Time, bool, error) {
return b.args[1].EvalTime(b.ctx, row)
}

type releaseAllLocksFunctionClass struct {
@@ -15,9 +15,11 @@ package expression
import (
"math"
"strings"
"time"

. "github.com/pingcap/check"
"github.com/pingcap/parser/ast"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/chunk"
"github.com/pingcap/tidb/util/testleak"
@@ -320,3 +322,46 @@ func (s *testEvaluatorSuite) TestIsIPv4Compat(c *C) {
c.Assert(err, IsNil)
c.Assert(r, testutil.DatumEquals, types.NewDatum(0))
}

func (s *testEvaluatorSuite) TestNameConst(c *C) {
defer testleak.AfterTest(c)()
dec := types.NewDecFromFloatForTest(123.123)
tm := types.Time{Time: types.FromGoTime(time.Now()), Fsp: 6, Type: mysql.TypeDatetime}
du := types.Duration{Duration: time.Duration(12*time.Hour + 1*time.Minute + 1*time.Second), Fsp: types.DefaultFsp}
cases := []struct {
colName string
arg interface{}
isNil bool
asserts func(d types.Datum)
}{
{"test_int", 3, false, func(d types.Datum) {
c.Assert(d.GetInt64(), Equals, int64(3))
}},
{"test_float", 3.14159, false, func(d types.Datum) {
c.Assert(d.GetFloat64(), Equals, 3.14159)
}},
{"test_string", "TiDB", false, func(d types.Datum) {
c.Assert(d.GetString(), Equals, "TiDB")
}},
{"test_null", nil, true, func(d types.Datum) {
c.Assert(d.Kind(), Equals, types.KindNull)
}},
{"test_decimal", dec, false, func(d types.Datum) {
c.Assert(d.GetMysqlDecimal().String(), Equals, dec.String())
}},
{"test_time", tm, false, func(d types.Datum) {
c.Assert(d.GetMysqlTime().String(), Equals, tm.String())
}},
{"test_duration", du, false, func(d types.Datum) {
c.Assert(d.GetMysqlDuration().String(), Equals, du.String())
}},
}

for _, t := range cases {
f, err := newFunctionForTest(s.ctx, ast.NameConst, s.primitiveValsToConstants([]interface{}{t.colName, t.arg})...)
c.Assert(err, IsNil)
d, err := f.Eval(chunk.Row{})
c.Assert(err, IsNil)
t.asserts(d)
}
}
@@ -37,6 +37,7 @@ import (
"github.com/pingcap/tidb/table"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/mock"
"github.com/pingcap/tidb/util/sqlexec"
"github.com/pingcap/tidb/util/testkit"
"github.com/pingcap/tidb/util/testleak"
"github.com/pingcap/tidb/util/testutil"
@@ -3935,6 +3936,42 @@ func (s *testIntegrationSuite) TestValuesFloat32(c *C) {
tk.MustQuery(`select * from t;`).Check(testkit.Rows(`1 0.02`))
}

func (s *testIntegrationSuite) TestFuncNameConst(c *C) {
tk := testkit.NewTestKit(c, s.store)
defer s.cleanEnv(c)
tk.MustExec("USE test;")
tk.MustExec("DROP TABLE IF EXISTS t;")
tk.MustExec("CREATE TABLE t(a CHAR(20), b VARCHAR(20), c BIGINT);")
tk.MustExec("INSERT INTO t (b, c) values('hello', 1);")

r := tk.MustQuery("SELECT name_const('test_int', 1), name_const('test_float', 3.1415);")
r.Check(testkit.Rows("1 3.1415"))
r = tk.MustQuery("SELECT name_const('test_string', 'hello'), name_const('test_nil', null);")
r.Check(testkit.Rows("hello <nil>"))
r = tk.MustQuery("SELECT name_const('test_string', 1) + c FROM t;")
r.Check(testkit.Rows("2"))
r = tk.MustQuery("SELECT concat('hello', name_const('test_string', 'world')) FROM t;")
r.Check(testkit.Rows("helloworld"))
err := tk.ExecToErr(`select name_const(a,b) from t;`)
c.Assert(err.Error(), Equals, "[planner:1210]Incorrect arguments to NAME_CONST")
err = tk.ExecToErr(`select name_const(a,"hello") from t;`)
c.Assert(err.Error(), Equals, "[planner:1210]Incorrect arguments to NAME_CONST")
err = tk.ExecToErr(`select name_const("hello", b) from t;`)
c.Assert(err.Error(), Equals, "[planner:1210]Incorrect arguments to NAME_CONST")
err = tk.ExecToErr(`select name_const("hello", 1+1) from t;`)
c.Assert(err.Error(), Equals, "[planner:1210]Incorrect arguments to NAME_CONST")
err = tk.ExecToErr(`select name_const(concat('a', 'b'), 555) from t;`)
c.Assert(err.Error(), Equals, "[planner:1210]Incorrect arguments to NAME_CONST")
err = tk.ExecToErr(`select name_const(555) from t;`)
c.Assert(err.Error(), Equals, "[expression:1582]Incorrect parameter count in the call to native function 'name_const'")

var rs sqlexec.RecordSet
rs, err = tk.Exec(`select name_const("hello", 1);`)
c.Assert(err, IsNil)
c.Assert(len(rs.Fields()), Equals, 1)
c.Assert(rs.Fields()[0].Column.Name.L, Equals, "hello")
}

func (s *testIntegrationSuite) TestValuesEnum(c *C) {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test")
@@ -558,18 +558,29 @@ func (b *PlanBuilder) buildProjectionFieldNameFromColumns(field *ast.SelectField
}

// buildProjectionFieldNameFromExpressions builds the field name when field expression is a normal expression.
func (b *PlanBuilder) buildProjectionFieldNameFromExpressions(field *ast.SelectField) model.CIStr {
func (b *PlanBuilder) buildProjectionFieldNameFromExpressions(field *ast.SelectField) (model.CIStr, error) {
if agg, ok := field.Expr.(*ast.AggregateFuncExpr); ok && agg.F == ast.AggFuncFirstRow {
// When the query is select t.a from t group by a; The Column Name should be a but not t.a;
return agg.Args[0].(*ast.ColumnNameExpr).Name.Name
return agg.Args[0].(*ast.ColumnNameExpr).Name.Name, nil
}

innerExpr := getInnerFromParenthesesAndUnaryPlus(field.Expr)
funcCall, isFuncCall := innerExpr.(*ast.FuncCallExpr)
// When used to produce a result set column, NAME_CONST() causes the column to have the given name.
// See https://dev.mysql.com/doc/refman/5.7/en/miscellaneous-functions.html#function_name-const for details
if isFuncCall && funcCall.FnName.L == ast.NameConst {
if v, err := evalAstExpr(b.ctx, funcCall.Args[0]); err == nil {
This conversation was marked as resolved by zz-jason

This comment has been minimized.

Copy link
@eurekaka

eurekaka Feb 11, 2019

Contributor

Return error also if err is not nil?

This comment has been minimized.

Copy link
@spongedu

spongedu Feb 11, 2019

Author Contributor

yes, I make a mistake here. line 575 should be moved outside the curly brace in line 576.

This comment has been minimized.

Copy link
@zz-jason

zz-jason Feb 14, 2019

Member

seems we should directly return error as long as the first parameter is not a constant:

MySQL(root@localhost:test) > select name_const(1+1, 1) from t;
ERROR 1210 (HY000): Incorrect arguments to NAME_CONST

This comment has been minimized.

Copy link
@zz-jason

zz-jason Feb 14, 2019

Member

It is guaranteed aster preprocess: https://github.com/pingcap/tidb/pull/9261/files#diff-c95e0584bc651e7f623fb47d69f690bcR143. So we don't need to call evalAstExpr() to get the constant value.

This comment has been minimized.

Copy link
@zz-jason

zz-jason Feb 15, 2019

Member

@spongedu PTAL this comment. Could you abandon the function call of evalAstExpr()?

This comment has been minimized.

Copy link
@spongedu

spongedu Feb 15, 2019

Author Contributor

@zz-jason for name_const, we should get the first argument to rename the output column, so I think evalAstExpr is still needed?

if s, err := v.ToString(); err == nil {
return model.NewCIStr(s), nil
}
}
return model.NewCIStr(""), ErrWrongArguments.GenWithStackByArgs("NAME_CONST")
}
valueExpr, isValueExpr := innerExpr.(*driver.ValueExpr)

// Non-literal: Output as inputed, except that comments need to be removed.
if !isValueExpr {
return model.NewCIStr(parser.SpecFieldPattern.ReplaceAllStringFunc(field.Text(), parser.TrimComment))
return model.NewCIStr(parser.SpecFieldPattern.ReplaceAllStringFunc(field.Text(), parser.TrimComment)), nil
}

// Literal: Need special processing
@@ -585,21 +596,21 @@ func (b *PlanBuilder) buildProjectionFieldNameFromExpressions(field *ast.SelectF
fieldName := strings.TrimLeftFunc(projName, func(r rune) bool {
return !unicode.IsOneOf(mysql.RangeGraph, r)
})
return model.NewCIStr(fieldName)
return model.NewCIStr(fieldName), nil
case types.KindNull:
// See #4053, #3685
return model.NewCIStr("NULL")
return model.NewCIStr("NULL"), nil
default:
// Keep as it is.
if innerExpr.Text() != "" {
return model.NewCIStr(innerExpr.Text())
return model.NewCIStr(innerExpr.Text()), nil
}
return model.NewCIStr(field.Text())
return model.NewCIStr(field.Text()), nil
}
}

// buildProjectionField builds the field object according to SelectField in projection.
func (b *PlanBuilder) buildProjectionField(id, position int, field *ast.SelectField, expr expression.Expression) *expression.Column {
func (b *PlanBuilder) buildProjectionField(id, position int, field *ast.SelectField, expr expression.Expression) (*expression.Column, error) {
var origTblName, tblName, origColName, colName, dbName model.CIStr
if c, ok := expr.(*expression.Column); ok && !c.IsReferenced {
// Field is a column reference.
@@ -609,7 +620,10 @@ func (b *PlanBuilder) buildProjectionField(id, position int, field *ast.SelectFi
colName = field.AsName
} else {
// Other: field is an expression.
colName = b.buildProjectionFieldNameFromExpressions(field)
var err error
if colName, err = b.buildProjectionFieldNameFromExpressions(field); err != nil {
return nil, errors.Trace(err)
This conversation was marked as resolved by eurekaka

This comment has been minimized.

Copy link
@eurekaka

eurekaka Feb 11, 2019

Contributor

Once the error is with stack, we don't need errors.Trace either.

This comment has been minimized.

Copy link
@spongedu

spongedu Feb 11, 2019

Author Contributor

ok

}
}
return &expression.Column{
UniqueID: b.ctx.GetSessionVars().AllocPlanColumnID(),
@@ -619,7 +633,7 @@ func (b *PlanBuilder) buildProjectionField(id, position int, field *ast.SelectFi
OrigColName: origColName,
DBName: dbName,
RetType: expr.GetType(),
}
}, nil
}

// buildProjection returns a Projection plan and non-aux columns length.
@@ -648,7 +662,10 @@ func (b *PlanBuilder) buildProjection(p LogicalPlan, fields []*ast.SelectField,
expr = p.Schema().Columns[i]
}
proj.Exprs = append(proj.Exprs, expr)
col := b.buildProjectionField(proj.id, schema.Len()+1, field, expr)
col, err := b.buildProjectionField(proj.id, schema.Len()+1, field, expr)
if err != nil {
return nil, 0, errors.Trace(err)
This conversation was marked as resolved by spongedu

This comment has been minimized.

Copy link
@eurekaka

eurekaka Feb 11, 2019

Contributor

Ditto

}
schema.Append(col)
continue
}
@@ -660,7 +677,10 @@ func (b *PlanBuilder) buildProjection(p LogicalPlan, fields []*ast.SelectField,
p = np
proj.Exprs = append(proj.Exprs, newExpr)

col := b.buildProjectionField(proj.id, schema.Len()+1, field, newExpr)
col, err := b.buildProjectionField(proj.id, schema.Len()+1, field, newExpr)
if err != nil {
return nil, 0, errors.Trace(err)
This conversation was marked as resolved by eurekaka

This comment has been minimized.

Copy link
@eurekaka

eurekaka Feb 11, 2019

Contributor

Ditto

This comment has been minimized.

Copy link
@spongedu

spongedu Feb 11, 2019

Author Contributor

I think we'd better left errors.Trace because there may be more error returned by buildProjectionField not with stack in the future? I'm not sure.

}
schema.Append(col)
}
proj.SetSchema(schema)
ProTip! Use n and p to navigate between commits in a pull request.
You can’t perform that action at this time.