diff --git a/expression/builtin_other.go b/expression/builtin_other.go index f86363dd7715a..78f6bfdb2cc96 100644 --- a/expression/builtin_other.go +++ b/expression/builtin_other.go @@ -511,6 +511,9 @@ func (b *builtinValuesIntSig) Clone() builtinFunc { // evalInt evals a builtinValuesIntSig. // See https://dev.mysql.com/doc/refman/5.7/en/miscellaneous-functions.html#function_values func (b *builtinValuesIntSig) evalInt(_ types.Row) (int64, bool, error) { + if !b.ctx.GetSessionVars().StmtCtx.InInsertStmt { + return 0, true, nil + } values := b.ctx.GetSessionVars().CurrInsertValues if values == nil { return 0, true, errors.New("Session current insert values is nil") @@ -540,6 +543,9 @@ func (b *builtinValuesRealSig) Clone() builtinFunc { // evalReal evals a builtinValuesRealSig. // See https://dev.mysql.com/doc/refman/5.7/en/miscellaneous-functions.html#function_values func (b *builtinValuesRealSig) evalReal(_ types.Row) (float64, bool, error) { + if !b.ctx.GetSessionVars().StmtCtx.InInsertStmt { + return 0, true, nil + } values := b.ctx.GetSessionVars().CurrInsertValues if values == nil { return 0, true, errors.New("Session current insert values is nil") @@ -569,6 +575,9 @@ func (b *builtinValuesDecimalSig) Clone() builtinFunc { // evalDecimal evals a builtinValuesDecimalSig. // See https://dev.mysql.com/doc/refman/5.7/en/miscellaneous-functions.html#function_values func (b *builtinValuesDecimalSig) evalDecimal(_ types.Row) (*types.MyDecimal, bool, error) { + if !b.ctx.GetSessionVars().StmtCtx.InInsertStmt { + return nil, true, nil + } values := b.ctx.GetSessionVars().CurrInsertValues if values == nil { return nil, true, errors.New("Session current insert values is nil") @@ -598,6 +607,9 @@ func (b *builtinValuesStringSig) Clone() builtinFunc { // evalString evals a builtinValuesStringSig. // See https://dev.mysql.com/doc/refman/5.7/en/miscellaneous-functions.html#function_values func (b *builtinValuesStringSig) evalString(_ types.Row) (string, bool, error) { + if !b.ctx.GetSessionVars().StmtCtx.InInsertStmt { + return "", true, nil + } values := b.ctx.GetSessionVars().CurrInsertValues if values == nil { return "", true, errors.New("Session current insert values is nil") @@ -627,6 +639,9 @@ func (b *builtinValuesTimeSig) Clone() builtinFunc { // evalTime evals a builtinValuesTimeSig. // See https://dev.mysql.com/doc/refman/5.7/en/miscellaneous-functions.html#function_values func (b *builtinValuesTimeSig) evalTime(_ types.Row) (types.Time, bool, error) { + if !b.ctx.GetSessionVars().StmtCtx.InInsertStmt { + return types.Time{}, true, nil + } values := b.ctx.GetSessionVars().CurrInsertValues if values == nil { return types.Time{}, true, errors.New("Session current insert values is nil") @@ -656,6 +671,9 @@ func (b *builtinValuesDurationSig) Clone() builtinFunc { // evalDuration evals a builtinValuesDurationSig. // See https://dev.mysql.com/doc/refman/5.7/en/miscellaneous-functions.html#function_values func (b *builtinValuesDurationSig) evalDuration(_ types.Row) (types.Duration, bool, error) { + if !b.ctx.GetSessionVars().StmtCtx.InInsertStmt { + return types.Duration{}, true, nil + } values := b.ctx.GetSessionVars().CurrInsertValues if values == nil { return types.Duration{}, true, errors.New("Session current insert values is nil") @@ -685,6 +703,9 @@ func (b *builtinValuesJSONSig) Clone() builtinFunc { // evalJSON evals a builtinValuesJSONSig. // See https://dev.mysql.com/doc/refman/5.7/en/miscellaneous-functions.html#function_values func (b *builtinValuesJSONSig) evalJSON(_ types.Row) (json.BinaryJSON, bool, error) { + if !b.ctx.GetSessionVars().StmtCtx.InInsertStmt { + return json.BinaryJSON{}, true, nil + } values := b.ctx.GetSessionVars().CurrInsertValues if values == nil { return json.BinaryJSON{}, true, errors.New("Session current insert values is nil") diff --git a/expression/builtin_other_test.go b/expression/builtin_other_test.go index 8bb26c89cbbd6..c498a56f8d976 100644 --- a/expression/builtin_other_test.go +++ b/expression/builtin_other_test.go @@ -14,7 +14,6 @@ package expression import ( - "fmt" "math" "time" @@ -24,6 +23,7 @@ import ( "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/types/json" + "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/hack" "github.com/pingcap/tidb/util/testleak" ) @@ -201,20 +201,35 @@ func (s *testEvaluatorSuite) TestGetVar(c *C) { func (s *testEvaluatorSuite) TestValues(c *C) { defer testleak.AfterTest(c)() + + origin := s.ctx.GetSessionVars().StmtCtx.InInsertStmt + s.ctx.GetSessionVars().StmtCtx.InInsertStmt = false + defer func() { + s.ctx.GetSessionVars().StmtCtx.InInsertStmt = origin + }() + fc := &valuesFunctionClass{baseFunctionClass{ast.Values, 0, 0}, 1, types.NewFieldType(mysql.TypeVarchar)} _, err := fc.getFunction(s.ctx, s.datumsToConstants(types.MakeDatums(""))) c.Assert(err, ErrorMatches, "*Incorrect parameter count in the call to native function 'values'") + sig, err := fc.getFunction(s.ctx, s.datumsToConstants(types.MakeDatums())) c.Assert(err, IsNil) - _, err = evalBuiltinFunc(sig, nil) - c.Assert(err.Error(), Equals, "Session current insert values is nil") - s.ctx.GetSessionVars().CurrInsertValues = types.DatumRow(types.MakeDatums("1")) - _, err = evalBuiltinFunc(sig, nil) - c.Assert(err.Error(), Equals, fmt.Sprintf("Session current insert values len %d and column's offset %v don't match", 1, 1)) + + ret, err := evalBuiltinFunc(sig, chunk.Row{}) + c.Assert(err, IsNil) + c.Assert(ret.IsNull(), IsTrue) + + s.ctx.GetSessionVars().CurrInsertValues = chunk.MutRowFromDatums(types.MakeDatums("1")).ToRow() + ret, err = evalBuiltinFunc(sig, chunk.Row{}) + c.Assert(err, IsNil) + c.Assert(ret.IsNull(), IsTrue) + currInsertValues := types.MakeDatums("1", "2") - s.ctx.GetSessionVars().CurrInsertValues = types.DatumRow(currInsertValues) - ret, err := evalBuiltinFunc(sig, nil) + s.ctx.GetSessionVars().StmtCtx.InInsertStmt = true + s.ctx.GetSessionVars().CurrInsertValues = chunk.MutRowFromDatums(currInsertValues).ToRow() + ret, err = evalBuiltinFunc(sig, chunk.Row{}) c.Assert(err, IsNil) + cmp, err := ret.CompareDatum(nil, &currInsertValues[1]) c.Assert(err, IsNil) c.Assert(cmp, Equals, 0) diff --git a/expression/integration_test.go b/expression/integration_test.go index 37f845af846c8..3e7027ba0f591 100644 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -3394,3 +3394,13 @@ func (s *testIntegrationSuite) TestDecimalMul(c *C) { res := tk.MustQuery("select * from t;") res.Check(testkit.Rows("0.55125221922461136")) } + +func (s *testIntegrationSuite) TestValuesInNonInsertStmt(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec(`use test;`) + tk.MustExec(`drop table if exists t;`) + tk.MustExec(`create table t(a bigint, b double, c decimal, d varchar(20), e datetime, f time, g json);`) + tk.MustExec(`insert into t values(1, 1.1, 2.2, "abc", "2018-10-24", NOW(), "12");`) + res := tk.MustQuery(`select values(a), values(b), values(c), values(d), values(e), values(f), values(g) from t;`) + res.Check(testkit.Rows(` `)) +}