Skip to content

Commit

Permalink
mysql, sessionctx: refine set/get variable 'SQL_MODE' (#4530)
Browse files Browse the repository at this point in the history
  • Loading branch information
spongedu authored and zz-jason committed Sep 15, 2017
1 parent a7bcc2b commit 74d5ce4
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 22 deletions.
21 changes: 21 additions & 0 deletions expression/integration_test.go
Expand Up @@ -2751,3 +2751,24 @@ func (s *testIntegrationSuite) TestFuncJSON(c *C) {
r = tk.MustQuery(`select json_extract(json_object(1,2,3,4), '$."1"')`)
r.Check(testkit.Rows("2"))
}

func (s *testIntegrationSuite) TestSetVariables(c *C) {
tk := testkit.NewTestKit(c, s.store)
defer func() {
s.cleanEnv(c)
testleak.AfterTest(c)()
}()
_, err := tk.Exec("set sql_mode='adfasdfadsfdasd';")
c.Assert(err, NotNil)
_, err = tk.Exec("set @@sql_mode='adfasdfadsfdasd';")
c.Assert(err, NotNil)
_, err = tk.Exec("set @@global.sql_mode='adfasdfadsfdasd';")
c.Assert(err, NotNil)
_, err = tk.Exec("set @@session.sql_mode='adfasdfadsfdasd';")
c.Assert(err, NotNil)

var r *testkit.Result
_, err = tk.Exec("set @@session.sql_mode=',NO_ZERO_DATE';")
r = tk.MustQuery(`select @@session.sql_mode`)
r.Check(testkit.Rows("NO_ZERO_DATE"))
}
39 changes: 31 additions & 8 deletions mysql/const.go
@@ -1,4 +1,4 @@
// Copyright 2015 PingCAP, Inc.
// Copyright 2017 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand All @@ -18,6 +18,10 @@ import (
"strings"
)

func newInvalidModeErr(s string) error {
return NewErr(ErrWrongValueForVar, "sql_mode", s)
}

// Version information.
const (
MinProtocolVersion byte = 10
Expand Down Expand Up @@ -425,14 +429,33 @@ const (
ModePadCharToFullLength
)

// GetSQLMode gets the sql mode for string literal.
func GetSQLMode(str string) SQLMode {
str = strings.ToUpper(str)
mode, ok := Str2SQLMode[str]
if !ok {
return ModeNone
// FormatSQLModeStr re-format 'SQL_MODE' variable.
func FormatSQLModeStr(s string) string {
s = strings.ToUpper(strings.TrimRight(s, " "))
parts := strings.Split(s, ",")
var nonEmptyParts []string
for i := 0; i < len(parts); i++ {
if len(parts[i]) == 0 {
continue
}
nonEmptyParts = append(nonEmptyParts, parts[i])
}
return strings.Join(nonEmptyParts, ",")
}

// GetSQLMode gets the sql mode for string literal. SQL_mode is a list of different modes separated by commas.
// The input string must be formatted by 'FormatSQLModeStr'
func GetSQLMode(s string) (SQLMode, error) {
strs := strings.Split(s, ",")
var sqlMode SQLMode
for i, length := 0, len(strs); i < length; i++ {
mode, ok := Str2SQLMode[strs[i]]
if !ok && strs[i] != "" {
return sqlMode, newInvalidModeErr(strs[i])
}
sqlMode = sqlMode | mode
}
return mode
return sqlMode, nil
}

// Str2SQLMode is the string represent of sql_mode to sql_mode map.
Expand Down
43 changes: 36 additions & 7 deletions mysql/const_test.go
@@ -1,4 +1,4 @@
// Copyright 2015 PingCAP, Inc.
// Copyright 2017 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand All @@ -16,7 +16,6 @@ package mysql
import (
. "github.com/pingcap/check"
"github.com/pingcap/tidb/util/testleak"
"strings"
"testing"
)

Expand All @@ -30,6 +29,40 @@ var _ = Suite(&testMySQLConstSuite{})
type testMySQLConstSuite struct {
}

func (s *testMySQLConstSuite) TestGetSQLMode(c *C) {
defer testleak.AfterTest(c)()

positiveCases := []struct {
arg string
}{
{"NO_ZERO_DATE"},
{",,NO_ZERO_DATE"},
{"NO_ZERO_DATE,NO_ZERO_IN_DATE"},
{""},
{", "},
{","},
}

for _, t := range positiveCases {
_, err := GetSQLMode(FormatSQLModeStr(t.arg))
c.Assert(err, IsNil)
}

negativeCases := []struct {
arg string
}{
{"NO_ZERO_DATE, NO_ZERO_IN_DATE"},
{"NO_ZERO_DATE,adfadsdfasdfads"},
{", ,NO_ZERO_DATE"},
{" ,"},
}

for _, t := range negativeCases {
_, err := GetSQLMode(FormatSQLModeStr(t.arg))
c.Assert(err, NotNil)
}
}

func (s *testMySQLConstSuite) TestSQLMode(c *C) {
defer testleak.AfterTest(c)()

Expand All @@ -46,11 +79,7 @@ func (s *testMySQLConstSuite) TestSQLMode(c *C) {
}

for _, t := range tests {
modes := strings.Split(t.arg, ",")
var sqlMode SQLMode
for _, mode := range modes {
sqlMode = sqlMode | GetSQLMode(mode)
}
sqlMode, _ := GetSQLMode(t.arg)
c.Assert(sqlMode.HasNoZeroDateMode(), Equals, t.hasNoZeroDateMode)
c.Assert(sqlMode.HasNoZeroInDateMode(), Equals, t.hasNoZeroInDateMode)
}
Expand Down
6 changes: 6 additions & 0 deletions session.go
Expand Up @@ -616,6 +616,12 @@ func (s *session) GetGlobalSysVar(name string) (string, error) {

// SetGlobalSysVar implements GlobalVarAccessor.SetGlobalSysVar interface.
func (s *session) SetGlobalSysVar(name string, value string) error {
if name == variable.SQLModeVar {
value = mysql.FormatSQLModeStr(value)
if _, err := mysql.GetSQLMode(value); err != nil {
return errors.Trace(err)
}
}
sql := fmt.Sprintf(`REPLACE %s.%s VALUES ('%s', '%s');`,
mysql.SystemDB, mysql.GlobalVariablesTable, strings.ToLower(name), value)
_, _, err := s.ExecRestrictedSQL(s, sql)
Expand Down
16 changes: 9 additions & 7 deletions sessionctx/varsutil/varsutil.go
Expand Up @@ -39,7 +39,6 @@ func GetSessionSystemVar(s *variable.SessionVars, key string) (string, error) {
case variable.TiDBCurrentTS:
return fmt.Sprintf("%d", s.TxnCtx.StartTS), nil
}

sVal, ok := s.Systems[key]
if ok {
return sVal, nil
Expand Down Expand Up @@ -68,7 +67,11 @@ func GetGlobalSystemVar(s *variable.SessionVars, key string) (string, error) {
} else if sysVar.Scope == variable.ScopeNone {
return sysVar.Value, nil
}
return s.GlobalVarsAccessor.GetGlobalSysVar(key)
gVal, err := s.GlobalVarsAccessor.GetGlobalSysVar(key)
if err != nil {
return "", errors.Trace(err)
}
return gVal, nil
}

// epochShiftBits is used to reserve logical part of the timestamp.
Expand Down Expand Up @@ -99,18 +102,17 @@ func SetSessionSystemVar(vars *variable.SessionVars, name string, value types.Da
return errors.Trace(err)
}
case variable.SQLModeVar:
sVal = strings.ToUpper(sVal)
sVal = mysql.FormatSQLModeStr(sVal)
// TODO: Remove this latter.
if strings.Contains(sVal, "STRICT_TRANS_TABLES") || strings.Contains(sVal, "STRICT_ALL_TABLES") {
vars.StrictSQLMode = true
} else {
vars.StrictSQLMode = false
}
// Modes is a list of different modes separated by commas.
modes := strings.Split(sVal, ",")
var sqlMode mysql.SQLMode
for _, mode := range modes {
sqlMode = sqlMode | mysql.GetSQLMode(mode)
sqlMode, err2 := mysql.GetSQLMode(sVal)
if err2 != nil {
return errors.Trace(err2)
}
vars.SQLMode = sqlMode
case variable.TiDBSnapshot:
Expand Down

0 comments on commit 74d5ce4

Please sign in to comment.