Skip to content

Commit

Permalink
support set variable, prepare and insert stmt. (#1359)
Browse files Browse the repository at this point in the history
  • Loading branch information
hanfei1991 committed Jun 29, 2016
1 parent 87bbb12 commit a75f677
Show file tree
Hide file tree
Showing 7 changed files with 120 additions and 16 deletions.
5 changes: 5 additions & 0 deletions executor/executor_simple_test.go
Expand Up @@ -21,6 +21,7 @@ import (
"github.com/pingcap/tidb/meta"
"github.com/pingcap/tidb/model"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/plan"
"github.com/pingcap/tidb/plan/statistics"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/sessionctx/variable"
Expand All @@ -30,6 +31,7 @@ import (
)

func (s *testSuite) TestCharsetDatabase(c *C) {
plan.UseNewPlanner = true
defer testleak.AfterTest(c)()
tk := testkit.NewTestKit(c, s.store)
testSQL := `create database if not exists cd_test_utf8 CHARACTER SET utf8 COLLATE utf8_bin;`
Expand All @@ -47,9 +49,11 @@ func (s *testSuite) TestCharsetDatabase(c *C) {
tk.MustExec(testSQL)
tk.MustQuery(`select @@character_set_database;`).Check(testkit.Rows("latin1"))
tk.MustQuery(`select @@collation_database;`).Check(testkit.Rows("latin1_swedish_ci"))
plan.UseNewPlanner = false
}

func (s *testSuite) TestSetVar(c *C) {
plan.UseNewPlanner = true
defer testleak.AfterTest(c)()
tk := testkit.NewTestKit(c, s.store)
testSQL := "SET @a = 1;"
Expand Down Expand Up @@ -106,6 +110,7 @@ func (s *testSuite) TestSetVar(c *C) {
testSQL = "SET @@global.autocommit=1, @issue998b=6;"
tk.MustExec(testSQL)
tk.MustQuery(`select @issue998b, @@global.autocommit;`).Check(testkit.Rows("6 1"))
plan.UseNewPlanner = false
}

func (s *testSuite) TestSetCharset(c *C) {
Expand Down
8 changes: 8 additions & 0 deletions executor/executor_test.go
Expand Up @@ -127,6 +127,7 @@ func (s *testSuite) TestAdmin(c *C) {
}

func (s *testSuite) TestPrepared(c *C) {
plan.UseNewPlanner = true
defer testleak.AfterTest(c)()
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test")
Expand Down Expand Up @@ -182,6 +183,7 @@ func (s *testSuite) TestPrepared(c *C) {
exec.Fields()
exec.Next()
exec.Close()
plan.UseNewPlanner = false
}

func (s *testSuite) fillData(tk *testkit.TestKit, table string) {
Expand Down Expand Up @@ -591,6 +593,7 @@ func (s *testSuite) TestSelectWithoutFrom(c *C) {
}

func (s *testSuite) TestSelectLimit(c *C) {
plan.UseNewPlanner = true
defer testleak.AfterTest(c)()
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test")
Expand Down Expand Up @@ -629,6 +632,7 @@ func (s *testSuite) TestSelectLimit(c *C) {
_, err := tk.Exec("select * from select_limit limit 18446744073709551616 offset 3;")
c.Assert(err, NotNil)
tk.MustExec("rollback")
plan.UseNewPlanner = false
}

func (s *testSuite) TestSelectOrderBy(c *C) {
Expand Down Expand Up @@ -732,6 +736,7 @@ func (s *testSuite) TestSelectDistinct(c *C) {
}

func (s *testSuite) TestSelectErrorRow(c *C) {
plan.UseNewPlanner = true
defer testleak.AfterTest(c)()
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test")
Expand Down Expand Up @@ -762,6 +767,7 @@ func (s *testSuite) TestSelectErrorRow(c *C) {
c.Assert(err, NotNil)

tk.MustExec("commit")
plan.UseNewPlanner = false
}

func (s *testSuite) TestUpdate(c *C) {
Expand Down Expand Up @@ -1220,6 +1226,7 @@ func (s *testSuite) TestIndexReverseOrder(c *C) {
}

func (s *testSuite) TestTableReverseOrder(c *C) {
plan.UseNewPlanner = true
defer testleak.AfterTest(c)()
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test")
Expand All @@ -1230,6 +1237,7 @@ func (s *testSuite) TestTableReverseOrder(c *C) {
result.Check(testkit.Rows("9", "8", "7", "6", "5", "4", "3", "2", "1"))
result = tk.MustQuery("select a from t where a <3 or (a >=6 and a < 8) order by a desc")
result.Check(testkit.Rows("7", "6", "2", "1"))
plan.UseNewPlanner = false
}

func (s *testSuite) TestInSubquery(c *C) {
Expand Down
3 changes: 2 additions & 1 deletion executor/executor_write.go
Expand Up @@ -24,6 +24,7 @@ import (
"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/plan"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/table"
Expand Down Expand Up @@ -650,7 +651,7 @@ func (e *InsertValues) getRow(cols []*table.Column, list []ast.ExprNode, default

func (e *InsertValues) getRowsSelect(cols []*table.Column) ([][]types.Datum, error) {
// process `insert|replace into ... select ... from ...`
if len(e.SelectExec.Fields()) != len(cols) {
if (!plan.UseNewPlanner && len(e.SelectExec.Fields()) != len(cols)) || (plan.UseNewPlanner && len(e.SelectExec.Schema()) != len(cols)) {
return nil, errors.Errorf("Column count %d doesn't match value count %d", len(cols), len(e.SelectExec.Fields()))
}
var rows [][]types.Datum
Expand Down
5 changes: 3 additions & 2 deletions expression/expression.go
Expand Up @@ -231,9 +231,10 @@ func NewFunction(funcName string, retType *types.FieldType, args ...Expression)
log.Errorf("Function %s is not implemented.", funcName)
return nil
}

funcArgs := make([]Expression, len(args))
copy(funcArgs, args)
return &ScalarFunction{
Args: args,
Args: funcArgs,
FuncName: model.NewCIStr(funcName),
RetType: retType,
Function: f.F}
Expand Down
102 changes: 93 additions & 9 deletions plan/expression_rewriter.go
@@ -1,6 +1,8 @@
package plan

import (
"strings"

"github.com/juju/errors"
"github.com/pingcap/tidb/ast"
"github.com/pingcap/tidb/context"
Expand All @@ -9,6 +11,7 @@ import (
"github.com/pingcap/tidb/infoschema"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/parser/opcode"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/util/types"
)

Expand Down Expand Up @@ -46,8 +49,8 @@ func getRowLen(e expression.Expression) int {
if f, ok := e.(*expression.ScalarFunction); ok && f.FuncName.L == ast.RowFunc {
return len(f.Args)
}
if f, ok := e.(*expression.Constant); ok && f.RetType.Tp == types.KindRow {
return len(f.Value.GetRow())
if c, ok := e.(*expression.Constant); ok && c.Value.Kind() == types.KindRow {
return len(c.Value.GetRow())
}
return 1
}
Expand Down Expand Up @@ -231,7 +234,7 @@ func (er *expressionRewriter) Enter(inNode ast.Node) (ast.Node, bool) {
for _, col := range np.GetSchema() {
args = append(args, col.DeepCopy())
}
rexpr = expression.NewFunction(ast.RowFunc, types.NewFieldType(types.KindRow), args...)
rexpr = expression.NewFunction(ast.RowFunc, nil, args...)
}
// a in (subq) will be rewrited as a = any(subq).
// a not in (subq) will be rewrited as a != all(subq).
Expand Down Expand Up @@ -263,7 +266,7 @@ func (er *expressionRewriter) Enter(inNode ast.Node) (ast.Node, bool) {
newCols = append(newCols, col.DeepCopy())
}
er.ctxStack = append(er.ctxStack,
expression.NewFunction(ast.RowFunc, types.NewFieldType(types.KindRow), newCols...))
expression.NewFunction(ast.RowFunc, nil, newCols...))
} else {
er.ctxStack = append(er.ctxStack, np.GetSchema()[0])
}
Expand All @@ -287,7 +290,7 @@ func (er *expressionRewriter) Enter(inNode ast.Node) (ast.Node, bool) {
RetType: np.GetSchema()[i].GetType()})
}
er.ctxStack = append(er.ctxStack,
expression.NewFunction(ast.RowFunc, types.NewFieldType(types.KindRow), newCols...))
expression.NewFunction(ast.RowFunc, nil, newCols...))
} else {
er.ctxStack = append(er.ctxStack, np.GetSchema()[0])
}
Expand All @@ -312,8 +315,9 @@ func (er *expressionRewriter) Leave(inNode ast.Node) (retNode ast.Node, ok bool)
rows = append(rows, er.ctxStack[i])
}
er.ctxStack = er.ctxStack[:stkLen-length]
er.ctxStack = append(er.ctxStack, expression.NewFunction(ast.RowFunc, types.NewFieldType(types.KindRow), rows...))

er.ctxStack = append(er.ctxStack, expression.NewFunction(ast.RowFunc, nil, rows...))
case *ast.VariableExpr:
return inNode, er.rewriteVariable(v)
case *ast.FuncCallExpr:
er.funcCallToScalarFunc(v)
case *ast.PositionExpr:
Expand All @@ -326,7 +330,10 @@ func (er *expressionRewriter) Leave(inNode ast.Node) (retNode ast.Node, ok bool)
er.toColumn(v)
case *ast.ColumnNameExpr, *ast.ParenthesesExpr, *ast.WhenClause, *ast.SubqueryExpr, *ast.ExistsSubqueryExpr, *ast.CompareSubqueryExpr:
case *ast.ValueExpr:
value := &expression.Constant{Value: v.Datum, RetType: types.NewFieldType(v.Datum.Kind())}
value := &expression.Constant{Value: v.Datum, RetType: v.Type}
er.ctxStack = append(er.ctxStack, value)
case *ast.ParamMarkerExpr:
value := &expression.Constant{Value: v.Datum, RetType: v.Type}
er.ctxStack = append(er.ctxStack, value)
case *ast.IsNullExpr:
if getRowLen(er.ctxStack[stkLen-1]) != 1 {
Expand Down Expand Up @@ -356,7 +363,7 @@ func (er *expressionRewriter) Leave(inNode ast.Node) (retNode ast.Node, ok bool)
return retNode, false
}
default:
function = expression.NewFunction(opcode.Ops[v.Op], v.Type, er.ctxStack[stkLen-2], er.ctxStack[stkLen-1])
function = expression.NewFunction(opcode.Ops[v.Op], v.Type, er.ctxStack[stkLen-2:]...)
}
er.ctxStack = er.ctxStack[:stkLen-2]
er.ctxStack = append(er.ctxStack, function)
Expand Down Expand Up @@ -387,6 +394,83 @@ func (er *expressionRewriter) Leave(inNode ast.Node) (retNode ast.Node, ok bool)
return inNode, true
}

func (er *expressionRewriter) rewriteVariable(v *ast.VariableExpr) bool {
stkLen := len(er.ctxStack)
name := strings.ToLower(v.Name)
sessionVars := variable.GetSessionVars(er.b.ctx)
globalVars := variable.GetGlobalVarAccessor(er.b.ctx)
if !v.IsSystem {
var d types.Datum
var err error
if v.Value != nil {
d, err = er.ctxStack[stkLen-1].Eval(nil, er.b.ctx)
if err != nil {
er.err = errors.Trace(err)
return false
}
er.ctxStack = er.ctxStack[:stkLen-1]
}
if !d.IsNull() {

strVal, err := d.ToString()
if err != nil {
er.err = errors.Trace(err)
return false
}
sessionVars.Users[name] = strings.ToLower(strVal)
er.ctxStack = append(er.ctxStack, &expression.Constant{Value: d, RetType: types.NewFieldType(mysql.TypeString)})
} else if value, ok := sessionVars.Users[name]; ok {
er.ctxStack = append(er.ctxStack, &expression.Constant{Value: types.NewDatum(value), RetType: types.NewFieldType(mysql.TypeString)})
} else {
// select null user vars is permitted.
er.ctxStack = append(er.ctxStack, &expression.Constant{RetType: types.NewFieldType(mysql.TypeNull)})
}
return true
}

sysVar, ok := variable.SysVars[name]
if !ok {
// select null sys vars is not permitted
er.err = variable.UnknownSystemVar.Gen("Unknown system variable '%s'", name)
return false
}
if sysVar.Scope == variable.ScopeNone {
er.ctxStack = append(er.ctxStack, &expression.Constant{Value: types.NewDatum(sysVar.Value), RetType: types.NewFieldType(mysql.TypeString)})
return true
}

if !v.IsGlobal {
d := sessionVars.GetSystemVar(name)
if d.IsNull() {
if sysVar.Scope&variable.ScopeGlobal == 0 {
d.SetString(sysVar.Value)
} else {
// Get global system variable and fill it in session.
globalVal, err := globalVars.GetGlobalSysVar(er.b.ctx, name)
if err != nil {
er.err = errors.Trace(err)
return false
}
d.SetString(globalVal)
err = sessionVars.SetSystemVar(name, d)
if err != nil {
er.err = errors.Trace(err)
return false
}
}
}
er.ctxStack = append(er.ctxStack, &expression.Constant{Value: d, RetType: types.NewFieldType(mysql.TypeString)})
return true
}
value, err := globalVars.GetGlobalSysVar(er.b.ctx, name)
if err != nil {
er.err = errors.Trace(err)
return false
}
er.ctxStack = append(er.ctxStack, &expression.Constant{Value: types.NewDatum(value), RetType: types.NewFieldType(mysql.TypeString)})
return true
}

func (er *expressionRewriter) notToScalarFunc(b bool, op string, tp *types.FieldType,
args ...expression.Expression) *expression.ScalarFunction {
opFunc := expression.NewFunction(op, tp, args...)
Expand Down
10 changes: 6 additions & 4 deletions plan/newplanbuilder.go
Expand Up @@ -325,16 +325,18 @@ func (b *planBuilder) buildNewUnion(union *ast.UnionStmt) Plan {
}

u.SetSchema(firstSchema)
var p Plan
p = u
if union.Distinct {
return b.buildNewDistinct(u)
p = b.buildNewDistinct(u)
}
if union.OrderBy != nil {
return b.buildNewSort(u, union.OrderBy.Items, nil)
p = b.buildNewSort(p, union.OrderBy.Items, nil)
}
if union.Limit != nil {
return b.buildNewLimit(u, union.Limit)
p = b.buildNewLimit(p, union.Limit)
}
return u
return p
}

// ByItems wraps a "by" item.
Expand Down
3 changes: 3 additions & 0 deletions plan/planbuilder.go
Expand Up @@ -104,6 +104,9 @@ func (b *planBuilder) build(node ast.Node) Plan {
}
return b.buildSelect(x)
case *ast.UnionStmt:
if UseNewPlanner {
return b.buildNewUnion(x)
}
return b.buildUnion(x)
case *ast.UpdateStmt:
return b.buildUpdate(x)
Expand Down

0 comments on commit a75f677

Please sign in to comment.