Skip to content

Commit

Permalink
expression: rewrite builtin function: FIND_IN_SET (#4247)
Browse files Browse the repository at this point in the history
  • Loading branch information
zz-jason committed Aug 21, 2017
1 parent 31edc09 commit acff4b6
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 28 deletions.
51 changes: 24 additions & 27 deletions expression/builtin_string.go
Expand Up @@ -1999,49 +1999,46 @@ func (c *findInSetFunctionClass) getFunction(args []Expression, ctx context.Cont
if err := c.verifyArgs(args); err != nil {
return nil, errors.Trace(err)
}
sig := &builtinFindInSetSig{newBaseBuiltinFunc(args, ctx)}
bf, err := newBaseBuiltinFuncWithTp(args, ctx, tpInt, tpString, tpString)
if err != nil {
return nil, errors.Trace(err)
}
bf.tp.Flen = 3
sig := &builtinFindInSetSig{baseIntBuiltinFunc{bf}}
return sig.setSelf(sig), nil
}

type builtinFindInSetSig struct {
baseBuiltinFunc
baseIntBuiltinFunc
}

// eval evals a builtinFindInSetSig.
// evalInt evals FIND_IN_SET(str,strlist).
// See https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_find-in-set
// TODO: This function can be optimized by using bit arithmetic when the first argument is
// a constant string and the second is a column of type SET.
func (b *builtinFindInSetSig) eval(row []types.Datum) (d types.Datum, err error) {
args, err := b.evalArgs(row)
if err != nil {
return types.Datum{}, errors.Trace(err)
}
// args[0] -> Str
// args[1] -> StrList
if args[0].IsNull() || args[1].IsNull() {
return
}
func (b *builtinFindInSetSig) evalInt(row []types.Datum) (int64, bool, error) {
sc := b.ctx.GetSessionVars().StmtCtx

str, err := args[0].ToString()
if err != nil {
return d, errors.Trace(err)
str, isNull, err := b.args[0].EvalString(row, sc)
if isNull || err != nil {
return 0, isNull, errors.Trace(err)
}
strlst, err := args[1].ToString()
if err != nil {
return d, errors.Trace(err)

strlist, isNull, err := b.args[1].EvalString(row, sc)
if isNull || err != nil {
return 0, isNull, errors.Trace(err)
}

d.SetInt64(0)
if len(strlst) == 0 {
return
if len(strlist) == 0 {
return 0, false, nil
}
for i, s := range strings.Split(strlst, ",") {
if s == str {
d.SetInt64(int64(i + 1))
return

for i, strInSet := range strings.Split(strlist, ",") {
if str == strInSet {
return int64(i + 1), false, nil
}
}
return
return 0, false, nil
}

type fieldFunctionClass struct {
Expand Down
3 changes: 2 additions & 1 deletion expression/builtin_string_test.go
Expand Up @@ -1185,9 +1185,10 @@ func (s *testEvaluatorSuite) TestFindInSet(c *C) {
fc := funcs[ast.FindInSet]
f, err := fc.getFunction(datumsToConstants(types.MakeDatums(t.str, t.strlst)), s.ctx)
c.Assert(err, IsNil)
c.Assert(f.isDeterministic(), IsTrue)
r, err := f.eval(nil)
c.Assert(err, IsNil)
c.Assert(r, testutil.DatumEquals, types.NewDatum(t.ret))
c.Assert(r, testutil.DatumEquals, types.NewDatum(t.ret), Commentf("FindInSet(%s, %s)", t.str, t.strlst))
}
}

Expand Down
6 changes: 6 additions & 0 deletions expression/integration_test.go
Expand Up @@ -792,6 +792,12 @@ func (s *testIntegrationSuite) TestStringBuiltin(c *C) {
result.Check(testkit.Rows("0 1777777777777777777777 1777777777777777777777 1777777777777777777777"))
result = tk.MustQuery(`select oct(-1.9), oct(1.9), oct(-1), oct(1), oct(-9999999999999999999999999), oct(9999999999999999999999999);`)
result.Check(testkit.Rows("1777777777777777777777 1 1777777777777777777777 1 1777777777777777777777 1777777777777777777777"))

// for find_in_set
result = tk.MustQuery(`select find_in_set("", ""), find_in_set("", ","), find_in_set("中文", "字符串,中文"), find_in_set("b,", "a,b,c,d");`)
result.Check(testkit.Rows("0 1 2 0"))
result = tk.MustQuery(`select find_in_set(NULL, ""), find_in_set("", NULL), find_in_set(1, "2,3,1");`)
result.Check(testkit.Rows("<nil> <nil> 3"))
}

func (s *testIntegrationSuite) TestEncryptionBuiltin(c *C) {
Expand Down
17 changes: 17 additions & 0 deletions plan/typeinfer_test.go
Expand Up @@ -328,6 +328,23 @@ func (s *testPlanSuite) createTestCase4StrFuncs() []typeInferTestCase {
{"oct(c_blob )", mysql.TypeVarString, charset.CharsetUTF8, 0, 64, types.UnspecifiedLength},
{"oct(c_set )", mysql.TypeVarString, charset.CharsetUTF8, 0, 64, types.UnspecifiedLength},
{"oct(c_enum )", mysql.TypeVarString, charset.CharsetUTF8, 0, 64, types.UnspecifiedLength},

{"find_in_set(c_int , c_text)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 3, 0},
{"find_in_set(c_bigint , c_text)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 3, 0},
{"find_in_set(c_float , c_text)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 3, 0},
{"find_in_set(c_double , c_text)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 3, 0},
{"find_in_set(c_decimal , c_text)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 3, 0},
{"find_in_set(c_datetime , c_text)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 3, 0},
{"find_in_set(c_time , c_text)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 3, 0},
{"find_in_set(c_timestamp, c_text)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 3, 0},
{"find_in_set(c_char , c_text)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 3, 0},
{"find_in_set(c_varchar , c_text)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 3, 0},
{"find_in_set(c_text , c_text)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 3, 0},
{"find_in_set(c_binary , c_text)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 3, 0},
{"find_in_set(c_varbinary, c_text)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 3, 0},
{"find_in_set(c_blob , c_text)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 3, 0},
{"find_in_set(c_set , c_text)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 3, 0},
{"find_in_set(c_enum , c_text)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 3, 0},
}
}

Expand Down

0 comments on commit acff4b6

Please sign in to comment.