Skip to content

Commit

Permalink
expression: fix "values" function in non-insert statement (#8019) (#8169
Browse files Browse the repository at this point in the history
)
  • Loading branch information
zz-jason committed Nov 5, 2018
1 parent 2f7d4b9 commit 3be9756
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 6 deletions.
21 changes: 21 additions & 0 deletions expression/builtin_other.go
Expand Up @@ -512,6 +512,9 @@ func (b *builtinValuesIntSig) Clone() builtinFunc {
// evalInt evals a builtinValuesIntSig.
// See https://dev.mysql.com/doc/refman/5.7/en/miscellaneous-functions.html#function_values
func (b *builtinValuesIntSig) evalInt(_ chunk.Row) (int64, bool, error) {
if !b.ctx.GetSessionVars().StmtCtx.InInsertStmt {
return 0, true, nil
}
row := b.ctx.GetSessionVars().CurrInsertValues
if row.IsEmpty() {
return 0, true, errors.New("Session current insert values is nil")
Expand Down Expand Up @@ -540,6 +543,9 @@ func (b *builtinValuesRealSig) Clone() builtinFunc {
// evalReal evals a builtinValuesRealSig.
// See https://dev.mysql.com/doc/refman/5.7/en/miscellaneous-functions.html#function_values
func (b *builtinValuesRealSig) evalReal(_ chunk.Row) (float64, bool, error) {
if !b.ctx.GetSessionVars().StmtCtx.InInsertStmt {
return 0, true, nil
}
row := b.ctx.GetSessionVars().CurrInsertValues
if row.IsEmpty() {
return 0, true, errors.New("Session current insert values is nil")
Expand Down Expand Up @@ -568,6 +574,9 @@ func (b *builtinValuesDecimalSig) Clone() builtinFunc {
// evalDecimal evals a builtinValuesDecimalSig.
// See https://dev.mysql.com/doc/refman/5.7/en/miscellaneous-functions.html#function_values
func (b *builtinValuesDecimalSig) evalDecimal(_ chunk.Row) (*types.MyDecimal, bool, error) {
if !b.ctx.GetSessionVars().StmtCtx.InInsertStmt {
return nil, true, nil
}
row := b.ctx.GetSessionVars().CurrInsertValues
if row.IsEmpty() {
return nil, true, errors.New("Session current insert values is nil")
Expand Down Expand Up @@ -596,6 +605,9 @@ func (b *builtinValuesStringSig) Clone() builtinFunc {
// evalString evals a builtinValuesStringSig.
// See https://dev.mysql.com/doc/refman/5.7/en/miscellaneous-functions.html#function_values
func (b *builtinValuesStringSig) evalString(_ chunk.Row) (string, bool, error) {
if !b.ctx.GetSessionVars().StmtCtx.InInsertStmt {
return "", true, nil
}
row := b.ctx.GetSessionVars().CurrInsertValues
if row.IsEmpty() {
return "", true, errors.New("Session current insert values is nil")
Expand Down Expand Up @@ -624,6 +636,9 @@ func (b *builtinValuesTimeSig) Clone() builtinFunc {
// evalTime evals a builtinValuesTimeSig.
// See https://dev.mysql.com/doc/refman/5.7/en/miscellaneous-functions.html#function_values
func (b *builtinValuesTimeSig) evalTime(_ chunk.Row) (types.Time, bool, error) {
if !b.ctx.GetSessionVars().StmtCtx.InInsertStmt {
return types.Time{}, true, nil
}
row := b.ctx.GetSessionVars().CurrInsertValues
if row.IsEmpty() {
return types.Time{}, true, errors.New("Session current insert values is nil")
Expand Down Expand Up @@ -652,6 +667,9 @@ func (b *builtinValuesDurationSig) Clone() builtinFunc {
// evalDuration evals a builtinValuesDurationSig.
// See https://dev.mysql.com/doc/refman/5.7/en/miscellaneous-functions.html#function_values
func (b *builtinValuesDurationSig) evalDuration(_ chunk.Row) (types.Duration, bool, error) {
if !b.ctx.GetSessionVars().StmtCtx.InInsertStmt {
return types.Duration{}, true, nil
}
row := b.ctx.GetSessionVars().CurrInsertValues
if row.IsEmpty() {
return types.Duration{}, true, errors.New("Session current insert values is nil")
Expand Down Expand Up @@ -681,6 +699,9 @@ func (b *builtinValuesJSONSig) Clone() builtinFunc {
// evalJSON evals a builtinValuesJSONSig.
// See https://dev.mysql.com/doc/refman/5.7/en/miscellaneous-functions.html#function_values
func (b *builtinValuesJSONSig) evalJSON(_ chunk.Row) (json.BinaryJSON, bool, error) {
if !b.ctx.GetSessionVars().StmtCtx.InInsertStmt {
return json.BinaryJSON{}, true, nil
}
row := b.ctx.GetSessionVars().CurrInsertValues
if row.IsEmpty() {
return json.BinaryJSON{}, true, errors.New("Session current insert values is nil")
Expand Down
26 changes: 20 additions & 6 deletions expression/builtin_other_test.go
Expand Up @@ -14,7 +14,6 @@
package expression

import (
"fmt"
"math"
"time"

Expand Down Expand Up @@ -202,20 +201,35 @@ func (s *testEvaluatorSuite) TestGetVar(c *C) {

func (s *testEvaluatorSuite) TestValues(c *C) {
defer testleak.AfterTest(c)()

origin := s.ctx.GetSessionVars().StmtCtx.InInsertStmt
s.ctx.GetSessionVars().StmtCtx.InInsertStmt = false
defer func() {
s.ctx.GetSessionVars().StmtCtx.InInsertStmt = origin
}()

fc := &valuesFunctionClass{baseFunctionClass{ast.Values, 0, 0}, 1, types.NewFieldType(mysql.TypeVarchar)}
_, err := fc.getFunction(s.ctx, s.datumsToConstants(types.MakeDatums("")))
c.Assert(err, ErrorMatches, "*Incorrect parameter count in the call to native function 'values'")

sig, err := fc.getFunction(s.ctx, s.datumsToConstants(types.MakeDatums()))
c.Assert(err, IsNil)
_, err = evalBuiltinFunc(sig, chunk.Row{})
c.Assert(err.Error(), Equals, "Session current insert values is nil")

ret, err := evalBuiltinFunc(sig, chunk.Row{})
c.Assert(err, IsNil)
c.Assert(ret.IsNull(), IsTrue)

s.ctx.GetSessionVars().CurrInsertValues = chunk.MutRowFromDatums(types.MakeDatums("1")).ToRow()
_, err = evalBuiltinFunc(sig, chunk.Row{})
c.Assert(err.Error(), Equals, fmt.Sprintf("Session current insert values len %d and column's offset %v don't match", 1, 1))
ret, err = evalBuiltinFunc(sig, chunk.Row{})
c.Assert(err, IsNil)
c.Assert(ret.IsNull(), IsTrue)

currInsertValues := types.MakeDatums("1", "2")
s.ctx.GetSessionVars().StmtCtx.InInsertStmt = true
s.ctx.GetSessionVars().CurrInsertValues = chunk.MutRowFromDatums(currInsertValues).ToRow()
ret, err := evalBuiltinFunc(sig, chunk.Row{})
ret, err = evalBuiltinFunc(sig, chunk.Row{})
c.Assert(err, IsNil)

cmp, err := ret.CompareDatum(nil, &currInsertValues[1])
c.Assert(err, IsNil)
c.Assert(cmp, Equals, 0)
Expand Down
10 changes: 10 additions & 0 deletions expression/integration_test.go
Expand Up @@ -3604,3 +3604,13 @@ func (s *testIntegrationSuite) TestDecimalMul(c *C) {
res := tk.MustQuery("select * from t;")
res.Check(testkit.Rows("0.55125221922461136"))
}

func (s *testIntegrationSuite) TestValuesInNonInsertStmt(c *C) {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec(`use test;`)
tk.MustExec(`drop table if exists t;`)
tk.MustExec(`create table t(a bigint, b double, c decimal, d varchar(20), e datetime, f time, g json);`)
tk.MustExec(`insert into t values(1, 1.1, 2.2, "abc", "2018-10-24", NOW(), "12");`)
res := tk.MustQuery(`select values(a), values(b), values(c), values(d), values(e), values(f), values(g) from t;`)
res.Check(testkit.Rows(`<nil> <nil> <nil> <nil> <nil> <nil> <nil>`))
}

0 comments on commit 3be9756

Please sign in to comment.