diff --git a/expression/builtin_string.go b/expression/builtin_string.go index 9ad71278fec2..bec3a11d415f 100644 --- a/expression/builtin_string.go +++ b/expression/builtin_string.go @@ -134,6 +134,7 @@ var ( _ builtinFunc = &builtinBinSig{} _ builtinFunc = &builtinEltSig{} _ builtinFunc = &builtinExportSetSig{} + _ builtinFunc = &builtinFormatWithLocaleSig{} _ builtinFunc = &builtinFormatSig{} _ builtinFunc = &builtinFromBase64Sig{} _ builtinFunc = &builtinToBase64Sig{} @@ -2450,51 +2451,71 @@ func (c *formatFunctionClass) getFunction(ctx context.Context, args []Expression if err := c.verifyArgs(args); err != nil { return nil, errors.Trace(err) } - sig := &builtinFormatSig{newBaseBuiltinFunc(args, ctx)} + argTps := make([]evalTp, 2, 3) + argTps[0], argTps[1] = tpString, tpString + if len(args) == 3 { + argTps = append(argTps, tpString) + } + bf := newBaseBuiltinFuncWithTp(args, ctx, tpString, argTps...) + bf.tp.Flen = mysql.MaxBlobWidth + var sig builtinFunc + if len(args) == 3 { + sig = &builtinFormatWithLocaleSig{baseStringBuiltinFunc{bf}} + } else { + sig = &builtinFormatSig{baseStringBuiltinFunc{bf}} + } return sig.setSelf(sig), nil } -type builtinFormatSig struct { - baseBuiltinFunc +type builtinFormatWithLocaleSig struct { + baseStringBuiltinFunc } -// eval evals a builtinFormatSig. +// evalString evals FORMAT(X,D,locale). // See https://dev.mysql.com/doc/refman/5.6/en/string-functions.html#function_format -func (b *builtinFormatSig) eval(row []types.Datum) (d types.Datum, err error) { - args, err := b.evalArgs(row) - if err != nil { - return d, errors.Trace(err) - } - if args[0].IsNull() { - d.SetNull() - return +func (b *builtinFormatWithLocaleSig) evalString(row []types.Datum) (string, bool, error) { + sc := b.ctx.GetSessionVars().StmtCtx + + x, isNull, err := b.args[0].EvalString(row, sc) + if isNull || err != nil { + return "", true, errors.Trace(err) } - arg0, err := args[0].ToString() - if err != nil { - return d, errors.Trace(err) + + d, isNull, err := b.args[1].EvalString(row, sc) + if isNull || err != nil { + return "", true, errors.Trace(err) } - arg1, err := args[1].ToString() - if err != nil { - return d, errors.Trace(err) + + locale, isNull, err := b.args[2].EvalString(row, sc) + if isNull || err != nil { + return "", true, errors.Trace(err) } - var arg2 string - if len(args) == 2 { - arg2 = "en_US" - } else if len(args) == 3 { - arg2, err = args[2].ToString() - if err != nil { - return d, errors.Trace(err) - } + formatString, err := mysql.GetLocaleFormatFunction(locale)(x, d) + return formatString, err != nil, errors.Trace(err) +} + +type builtinFormatSig struct { + baseStringBuiltinFunc +} + +// evalString evals FORMAT(X,D). +// See https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_format +func (b *builtinFormatSig) evalString(row []types.Datum) (string, bool, error) { + sc := b.ctx.GetSessionVars().StmtCtx + + x, isNull, err := b.args[0].EvalString(row, sc) + if isNull || err != nil { + return "", true, errors.Trace(err) } - formatString, err := mysql.GetLocaleFormatFunction(arg2)(arg0, arg1) - if err != nil { - return d, errors.Trace(err) + d, isNull, err := b.args[1].EvalString(row, sc) + if isNull || err != nil { + return "", true, errors.Trace(err) } - d.SetString(formatString) - return d, nil + formatString, err := mysql.GetLocaleFormatFunction("en_US")(x, d) + return formatString, err != nil, errors.Trace(err) } type fromBase64FunctionClass struct { diff --git a/expression/builtin_string_test.go b/expression/builtin_string_test.go index 9e5cf31a3370..66eae6b54bc0 100644 --- a/expression/builtin_string_test.go +++ b/expression/builtin_string_test.go @@ -1474,6 +1474,8 @@ func (s *testEvaluatorSuite) TestFormat(c *C) { fc := funcs[ast.Format] f, err := fc.getFunction(s.ctx, datumsToConstants(types.MakeDatums(tt.number, tt.precision, tt.locale))) c.Assert(err, IsNil) + c.Assert(f, NotNil) + c.Assert(f.canBeFolded(), IsTrue) r, err := f.eval(nil) c.Assert(err, IsNil) c.Assert(r, testutil.DatumEquals, types.NewDatum(tt.ret)) @@ -1483,6 +1485,8 @@ func (s *testEvaluatorSuite) TestFormat(c *C) { fc := funcs[ast.Format] f, err := fc.getFunction(s.ctx, datumsToConstants(types.MakeDatums(tt.number, tt.precision))) c.Assert(err, IsNil) + c.Assert(f, NotNil) + c.Assert(f.canBeFolded(), IsTrue) r, err := f.eval(nil) c.Assert(err, IsNil) c.Assert(r, testutil.DatumEquals, types.NewDatum(tt.ret)) @@ -1491,6 +1495,8 @@ func (s *testEvaluatorSuite) TestFormat(c *C) { fc2 := funcs[ast.Format] f2, err := fc2.getFunction(s.ctx, datumsToConstants(types.MakeDatums(formatTests2.number, formatTests2.precision, formatTests2.locale))) c.Assert(err, IsNil) + c.Assert(f2, NotNil) + c.Assert(f2.canBeFolded(), IsTrue) r2, err := f2.eval(nil) c.Assert(types.NewDatum(err), testutil.DatumEquals, types.NewDatum(errors.New("not implemented"))) c.Assert(r2, testutil.DatumEquals, types.NewDatum(formatTests2.ret)) @@ -1498,6 +1504,8 @@ func (s *testEvaluatorSuite) TestFormat(c *C) { fc3 := funcs[ast.Format] f3, err := fc3.getFunction(s.ctx, datumsToConstants(types.MakeDatums(formatTests3.number, formatTests3.precision, formatTests3.locale))) c.Assert(err, IsNil) + c.Assert(f3, NotNil) + c.Assert(f3.canBeFolded(), IsTrue) r3, err := f3.eval(nil) c.Assert(types.NewDatum(err), testutil.DatumEquals, types.NewDatum(errors.New("not support for the specific locale"))) c.Assert(r3, testutil.DatumEquals, types.NewDatum(formatTests3.ret)) diff --git a/expression/integration_test.go b/expression/integration_test.go index 94edca6c54dd..87d96754e5d5 100644 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -834,6 +834,17 @@ func (s *testIntegrationSuite) TestStringBuiltin(c *C) { result.Check(testkit.Rows("'aaaa' '' '\"\"' '\n\n'")) result = tk.MustQuery(`select quote(0121), quote(0000), quote("中文"), quote(NULL);`) result.Check(testkit.Rows("'121' '0' '中文' ")) + + // for format + result = tk.MustQuery(`select format(12332.1, 4), format(12332.2, 0), format(12332.2, 2,'en_US');`) + result.Check(testkit.Rows("12,332.1000 12,332 12,332.20")) + result = tk.MustQuery(`select format(NULL, 4), format(12332.2, NULL);`) + result.Check(testkit.Rows(" ")) + rs, err := tk.Exec(`select format(12332.2, 2,'es_EC');`) + c.Assert(err, IsNil) + _, err = tidb.GetRows(rs) + c.Assert(err, NotNil) + c.Assert(err.Error(), Matches, "not support for the specific locale") } func (s *testIntegrationSuite) TestEncryptionBuiltin(c *C) { diff --git a/plan/typeinfer_test.go b/plan/typeinfer_test.go index cc20c27b0b15..1f24bee14bb9 100644 --- a/plan/typeinfer_test.go +++ b/plan/typeinfer_test.go @@ -425,6 +425,9 @@ func (s *testPlanSuite) createTestCase4StrFuncs() []typeInferTestCase { {"quote(c_bigint_d )", mysql.TypeVarString, charset.CharsetUTF8, 0, 42, types.UnspecifiedLength}, {"quote(c_float_d )", mysql.TypeVarString, charset.CharsetUTF8, 0, 26, types.UnspecifiedLength}, {"quote(c_double_d )", mysql.TypeVarString, charset.CharsetUTF8, 0, 46, types.UnspecifiedLength}, + + {"format(c_double_d, c_double_d)", mysql.TypeLongBlob, charset.CharsetUTF8, 0, mysql.MaxBlobWidth, types.UnspecifiedLength}, + {"format(c_double_d, c_double_d, c_binary)", mysql.TypeLongBlob, charset.CharsetUTF8, 0, mysql.MaxBlobWidth, types.UnspecifiedLength}, } }