From acff4b60ec93de1685f8504cb8a8da380e57f1ee Mon Sep 17 00:00:00 2001 From: Jian Zhang Date: Mon, 21 Aug 2017 11:31:22 +0800 Subject: [PATCH] expression: rewrite builtin function: FIND_IN_SET (#4247) --- expression/builtin_string.go | 51 +++++++++++++++---------------- expression/builtin_string_test.go | 3 +- expression/integration_test.go | 6 ++++ plan/typeinfer_test.go | 17 +++++++++++ 4 files changed, 49 insertions(+), 28 deletions(-) diff --git a/expression/builtin_string.go b/expression/builtin_string.go index f63395c34753..8005aaba98eb 100644 --- a/expression/builtin_string.go +++ b/expression/builtin_string.go @@ -1999,49 +1999,46 @@ func (c *findInSetFunctionClass) getFunction(args []Expression, ctx context.Cont if err := c.verifyArgs(args); err != nil { return nil, errors.Trace(err) } - sig := &builtinFindInSetSig{newBaseBuiltinFunc(args, ctx)} + bf, err := newBaseBuiltinFuncWithTp(args, ctx, tpInt, tpString, tpString) + if err != nil { + return nil, errors.Trace(err) + } + bf.tp.Flen = 3 + sig := &builtinFindInSetSig{baseIntBuiltinFunc{bf}} return sig.setSelf(sig), nil } type builtinFindInSetSig struct { - baseBuiltinFunc + baseIntBuiltinFunc } -// eval evals a builtinFindInSetSig. +// evalInt evals FIND_IN_SET(str,strlist). // See https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_find-in-set // TODO: This function can be optimized by using bit arithmetic when the first argument is // a constant string and the second is a column of type SET. -func (b *builtinFindInSetSig) eval(row []types.Datum) (d types.Datum, err error) { - args, err := b.evalArgs(row) - if err != nil { - return types.Datum{}, errors.Trace(err) - } - // args[0] -> Str - // args[1] -> StrList - if args[0].IsNull() || args[1].IsNull() { - return - } +func (b *builtinFindInSetSig) evalInt(row []types.Datum) (int64, bool, error) { + sc := b.ctx.GetSessionVars().StmtCtx - str, err := args[0].ToString() - if err != nil { - return d, errors.Trace(err) + str, isNull, err := b.args[0].EvalString(row, sc) + if isNull || err != nil { + return 0, isNull, errors.Trace(err) } - strlst, err := args[1].ToString() - if err != nil { - return d, errors.Trace(err) + + strlist, isNull, err := b.args[1].EvalString(row, sc) + if isNull || err != nil { + return 0, isNull, errors.Trace(err) } - d.SetInt64(0) - if len(strlst) == 0 { - return + if len(strlist) == 0 { + return 0, false, nil } - for i, s := range strings.Split(strlst, ",") { - if s == str { - d.SetInt64(int64(i + 1)) - return + + for i, strInSet := range strings.Split(strlist, ",") { + if str == strInSet { + return int64(i + 1), false, nil } } - return + return 0, false, nil } type fieldFunctionClass struct { diff --git a/expression/builtin_string_test.go b/expression/builtin_string_test.go index 531e20355a39..a8b22b053cbc 100644 --- a/expression/builtin_string_test.go +++ b/expression/builtin_string_test.go @@ -1185,9 +1185,10 @@ func (s *testEvaluatorSuite) TestFindInSet(c *C) { fc := funcs[ast.FindInSet] f, err := fc.getFunction(datumsToConstants(types.MakeDatums(t.str, t.strlst)), s.ctx) c.Assert(err, IsNil) + c.Assert(f.isDeterministic(), IsTrue) r, err := f.eval(nil) c.Assert(err, IsNil) - c.Assert(r, testutil.DatumEquals, types.NewDatum(t.ret)) + c.Assert(r, testutil.DatumEquals, types.NewDatum(t.ret), Commentf("FindInSet(%s, %s)", t.str, t.strlst)) } } diff --git a/expression/integration_test.go b/expression/integration_test.go index 5c44f39ec1a0..cc7b91051afe 100644 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -792,6 +792,12 @@ func (s *testIntegrationSuite) TestStringBuiltin(c *C) { result.Check(testkit.Rows("0 1777777777777777777777 1777777777777777777777 1777777777777777777777")) result = tk.MustQuery(`select oct(-1.9), oct(1.9), oct(-1), oct(1), oct(-9999999999999999999999999), oct(9999999999999999999999999);`) result.Check(testkit.Rows("1777777777777777777777 1 1777777777777777777777 1 1777777777777777777777 1777777777777777777777")) + + // for find_in_set + result = tk.MustQuery(`select find_in_set("", ""), find_in_set("", ","), find_in_set("中文", "字符串,中文"), find_in_set("b,", "a,b,c,d");`) + result.Check(testkit.Rows("0 1 2 0")) + result = tk.MustQuery(`select find_in_set(NULL, ""), find_in_set("", NULL), find_in_set(1, "2,3,1");`) + result.Check(testkit.Rows(" 3")) } func (s *testIntegrationSuite) TestEncryptionBuiltin(c *C) { diff --git a/plan/typeinfer_test.go b/plan/typeinfer_test.go index ed61a740c59f..cbe42e3230d0 100644 --- a/plan/typeinfer_test.go +++ b/plan/typeinfer_test.go @@ -328,6 +328,23 @@ func (s *testPlanSuite) createTestCase4StrFuncs() []typeInferTestCase { {"oct(c_blob )", mysql.TypeVarString, charset.CharsetUTF8, 0, 64, types.UnspecifiedLength}, {"oct(c_set )", mysql.TypeVarString, charset.CharsetUTF8, 0, 64, types.UnspecifiedLength}, {"oct(c_enum )", mysql.TypeVarString, charset.CharsetUTF8, 0, 64, types.UnspecifiedLength}, + + {"find_in_set(c_int , c_text)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 3, 0}, + {"find_in_set(c_bigint , c_text)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 3, 0}, + {"find_in_set(c_float , c_text)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 3, 0}, + {"find_in_set(c_double , c_text)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 3, 0}, + {"find_in_set(c_decimal , c_text)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 3, 0}, + {"find_in_set(c_datetime , c_text)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 3, 0}, + {"find_in_set(c_time , c_text)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 3, 0}, + {"find_in_set(c_timestamp, c_text)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 3, 0}, + {"find_in_set(c_char , c_text)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 3, 0}, + {"find_in_set(c_varchar , c_text)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 3, 0}, + {"find_in_set(c_text , c_text)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 3, 0}, + {"find_in_set(c_binary , c_text)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 3, 0}, + {"find_in_set(c_varbinary, c_text)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 3, 0}, + {"find_in_set(c_blob , c_text)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 3, 0}, + {"find_in_set(c_set , c_text)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 3, 0}, + {"find_in_set(c_enum , c_text)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 3, 0}, } }