diff --git a/expression/builtin.go b/expression/builtin.go index b6855cb33959b..b925e3ca9fb74 100644 --- a/expression/builtin.go +++ b/expression/builtin.go @@ -801,7 +801,7 @@ var funcs = map[string]functionClass{ ast.Field: &fieldFunctionClass{baseFunctionClass{ast.Field, 2, -1}}, ast.Format: &formatFunctionClass{baseFunctionClass{ast.Format, 2, 3}}, ast.FromBase64: &fromBase64FunctionClass{baseFunctionClass{ast.FromBase64, 1, 1}}, - ast.InsertFunc: &insertFuncFunctionClass{baseFunctionClass{ast.InsertFunc, 4, 4}}, + ast.InsertFunc: &insertFunctionClass{baseFunctionClass{ast.InsertFunc, 4, 4}}, ast.Instr: &instrFunctionClass{baseFunctionClass{ast.Instr, 2, 2}}, ast.Lcase: &lowerFunctionClass{baseFunctionClass{ast.Lcase, 1, 1}}, ast.Left: &leftFunctionClass{baseFunctionClass{ast.Left, 2, 2}}, diff --git a/expression/builtin_string.go b/expression/builtin_string.go index 67609d4fce239..682db7331ba72 100644 --- a/expression/builtin_string.go +++ b/expression/builtin_string.go @@ -78,7 +78,7 @@ var ( _ functionClass = &formatFunctionClass{} _ functionClass = &fromBase64FunctionClass{} _ functionClass = &toBase64FunctionClass{} - _ functionClass = &insertFuncFunctionClass{} + _ functionClass = &insertFunctionClass{} _ functionClass = &instrFunctionClass{} _ functionClass = &loadFileFunctionClass{} ) @@ -140,7 +140,8 @@ var ( _ builtinFunc = &builtinFormatSig{} _ builtinFunc = &builtinFromBase64Sig{} _ builtinFunc = &builtinToBase64Sig{} - _ builtinFunc = &builtinInsertFuncSig{} + _ builtinFunc = &builtinInsertBinarySig{} + _ builtinFunc = &builtinInsertSig{} _ builtinFunc = &builtinInstrSig{} _ builtinFunc = &builtinInstrBinarySig{} ) @@ -2701,70 +2702,103 @@ func splitToSubN(s string, n int) []string { return subs } -type insertFuncFunctionClass struct { +type insertFunctionClass struct { baseFunctionClass } -func (c *insertFuncFunctionClass) getFunction(ctx context.Context, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(args); err != nil { +func (c *insertFunctionClass) getFunction(ctx context.Context, args []Expression) (sig builtinFunc, err error) { + if err = c.verifyArgs(args); err != nil { return nil, errors.Trace(err) } - sig := &builtinInsertFuncSig{newBaseBuiltinFunc(args, ctx)} + bf := newBaseBuiltinFuncWithTp(args, ctx, tpString, tpString, tpInt, tpInt, tpString) + bf.tp.Flen = mysql.MaxBlobWidth + SetBinFlagOrBinStr(args[0].GetType(), bf.tp) + SetBinFlagOrBinStr(args[3].GetType(), bf.tp) + if types.IsBinaryStr(args[0].GetType()) { + sig = &builtinInsertBinarySig{baseStringBuiltinFunc{bf}} + } else { + sig = &builtinInsertSig{baseStringBuiltinFunc{bf}} + } return sig.setSelf(sig), nil } -type builtinInsertFuncSig struct { - baseBuiltinFunc +type builtinInsertBinarySig struct { + baseStringBuiltinFunc } -// eval evals a builtinInsertFuncSig. -// See https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_insert -func (b *builtinInsertFuncSig) eval(row []types.Datum) (d types.Datum, err error) { - args, err := b.evalArgs(row) - if err != nil { - return d, errors.Trace(err) +// evalString evals INSERT(str,pos,len,newstr). +// See https://dev.mysql.com/doc/refman/5.6/en/string-functions.html#function_insert +func (b *builtinInsertBinarySig) evalString(row []types.Datum) (string, bool, error) { + sc := b.ctx.GetSessionVars().StmtCtx + + str, isNull, err := b.args[0].EvalString(row, sc) + if isNull || err != nil { + return "", true, errors.Trace(err) } + strLength := int64(len(str)) - // Returns NULL if any argument is NULL. - if args[0].IsNull() || args[1].IsNull() || args[2].IsNull() || args[3].IsNull() { - return + pos, isNull, err := b.args[1].EvalInt(row, sc) + if isNull || err != nil { + return "", true, errors.Trace(err) + } + if pos < 1 || pos > strLength { + return str, false, nil } - str0, err := args[0].ToString() - if err != nil { - return d, errors.Trace(err) + length, isNull, err := b.args[2].EvalInt(row, sc) + if isNull || err != nil { + return "", true, errors.Trace(err) } - str := []rune(str0) - strLen := len(str) - posInt64, err := args[1].ToInt64(b.ctx.GetSessionVars().StmtCtx) - if err != nil { - return d, errors.Trace(err) + newstr, isNull, err := b.args[3].EvalString(row, sc) + if isNull || err != nil { + return "", true, errors.Trace(err) } - pos := int(posInt64) - lenInt64, err := args[2].ToInt64(b.ctx.GetSessionVars().StmtCtx) - if err != nil { - return d, errors.Trace(err) + if length > strLength-pos+1 || length < 0 { + return str[0:pos-1] + newstr, false, nil } - length := int(lenInt64) + return str[0:pos-1] + newstr + str[pos+length-1:], false, nil +} - newstr, err := args[3].ToString() - if err != nil { - return d, errors.Trace(err) +type builtinInsertSig struct { + baseStringBuiltinFunc +} + +// evalString evals INSERT(str,pos,len,newstr). +// See https://dev.mysql.com/doc/refman/5.6/en/string-functions.html#function_insert +func (b *builtinInsertSig) evalString(row []types.Datum) (string, bool, error) { + sc := b.ctx.GetSessionVars().StmtCtx + + str, isNull, err := b.args[0].EvalString(row, sc) + if isNull || err != nil { + return "", true, errors.Trace(err) } + runes := []rune(str) + runeLength := int64(len(runes)) - var s string - if pos < 1 || pos > strLen { - s = str0 - } else if length > strLen-pos+1 || length < 0 { - s = string(str[0:pos-1]) + newstr - } else { - s = string(str[0:pos-1]) + newstr + string(str[pos+length-1:]) + pos, isNull, err := b.args[1].EvalInt(row, sc) + if isNull || err != nil { + return "", true, errors.Trace(err) + } + if pos < 1 || pos > runeLength { + return str, false, nil } - d.SetString(s) - return d, nil + length, isNull, err := b.args[2].EvalInt(row, sc) + if isNull || err != nil { + return "", true, errors.Trace(err) + } + + newstr, isNull, err := b.args[3].EvalString(row, sc) + if isNull || err != nil { + return "", true, errors.Trace(err) + } + + if length > runeLength-pos+1 || length < 0 { + return string(runes[0:pos-1]) + newstr, false, nil + } + return string(runes[0:pos-1]) + newstr + string(runes[pos+length-1:]), false, nil } type instrFunctionClass struct { diff --git a/expression/builtin_string_test.go b/expression/builtin_string_test.go index 4c2f471d2fd99..e20b58a1ef7c7 100644 --- a/expression/builtin_string_test.go +++ b/expression/builtin_string_test.go @@ -1584,6 +1584,8 @@ func (s *testEvaluatorSuite) TestInsert(c *C) { for _, test := range tests { f, err := fc.getFunction(s.ctx, datumsToConstants(types.MakeDatums(test.args...))) c.Assert(err, IsNil) + c.Assert(f, NotNil) + c.Assert(f.canBeFolded(), IsTrue) result, err := f.eval(nil) c.Assert(err, IsNil) if test.expect == nil { diff --git a/expression/integration_test.go b/expression/integration_test.go index 82af750dac347..6d0a87cc41fb2 100644 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -835,6 +835,12 @@ func (s *testIntegrationSuite) TestStringBuiltin(c *C) { result = tk.MustQuery(`select quote(0121), quote(0000), quote("中文"), quote(NULL);`) result.Check(testkit.Rows("'121' '0' '中文' ")) + // for insert + result = tk.MustQuery(`select insert("中文", 1, 1, cast("aaa" as binary)), insert("ba", -1, 1, "aaa"), insert("ba", 1, 100, "aaa"), insert("ba", 100, 1, "aaa");`) + result.Check(testkit.Rows("aaa文 ba aaa ba")) + result = tk.MustQuery(`select insert("bb", NULL, 1, "aa"), insert("bb", 1, NULL, "aa"), insert(NULL, 1, 1, "aaa"), insert("bb", 1, 1, NULL);`) + result.Check(testkit.Rows(" ")) + // for export_set result = tk.MustQuery(`select export_set(7, "1", "0", ",", 65);`) result.Check(testkit.Rows("1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0")) diff --git a/plan/typeinfer_test.go b/plan/typeinfer_test.go index b2667ca8ade34..e21327bbda474 100644 --- a/plan/typeinfer_test.go +++ b/plan/typeinfer_test.go @@ -434,6 +434,11 @@ func (s *testPlanSuite) createTestCase4StrFuncs() []typeInferTestCase { {"quote(c_float_d )", mysql.TypeVarString, charset.CharsetUTF8, 0, 26, types.UnspecifiedLength}, {"quote(c_double_d )", mysql.TypeVarString, charset.CharsetUTF8, 0, 46, types.UnspecifiedLength}, + {"insert(c_varchar, c_int_d, c_int_d, c_varchar)", mysql.TypeLongBlob, charset.CharsetUTF8, 0, mysql.MaxBlobWidth, types.UnspecifiedLength}, + {"insert(c_varchar, c_int_d, c_int_d, c_binary)", mysql.TypeLongBlob, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxBlobWidth, types.UnspecifiedLength}, + {"insert(c_binary, c_int_d, c_int_d, c_varchar)", mysql.TypeLongBlob, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxBlobWidth, types.UnspecifiedLength}, + {"insert(c_binary, c_int_d, c_int_d, c_binary)", mysql.TypeLongBlob, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxBlobWidth, types.UnspecifiedLength}, + {"export_set(c_double_d, c_text_d, c_text_d)", mysql.TypeLongBlob, charset.CharsetUTF8, 0, mysql.MaxBlobWidth, types.UnspecifiedLength}, {"export_set(c_double_d, c_text_d, c_text_d, c_text_d)", mysql.TypeLongBlob, charset.CharsetUTF8, 0, mysql.MaxBlobWidth, types.UnspecifiedLength}, {"export_set(c_double_d, c_text_d, c_text_d, c_text_d, c_int_d)", mysql.TypeLongBlob, charset.CharsetUTF8, 0, mysql.MaxBlobWidth, types.UnspecifiedLength},