Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

planner/core: add DefaultExpr support for expressionRewriter #8540

Merged
merged 6 commits into from Jan 2, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
87 changes: 87 additions & 0 deletions planner/core/expression_rewriter.go
Expand Up @@ -27,6 +27,7 @@ import (
"github.com/pingcap/tidb/infoschema"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/table"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/types/parser_driver"
"github.com/pingcap/tidb/util/chunk"
Expand Down Expand Up @@ -169,6 +170,7 @@ type expressionRewriter struct {
insertPlan *Insert
}

// constructBinaryOpFunction converts binary operator functions
// 1. If op are EQ or NE or NullEQ, constructBinaryOpFunctions converts (a0,a1,a2) op (b0,b1,b2) to (a0 op b0) and (a1 op b1) and (a2 op b2)
// 2. Else constructBinaryOpFunctions converts (a0,a1,a2) op (b0,b1,b2) to
// `IF( a0 NE b0, a0 op b0,
Expand Down Expand Up @@ -799,6 +801,8 @@ func (er *expressionRewriter) Leave(originInNode ast.Node) (retNode ast.Node, ok
er.isNullToExpression(v)
case *ast.IsTruthExpr:
er.isTrueToScalarFunc(v)
case *ast.DefaultExpr:
er.evalDefaultExpr(v)
default:
er.err = errors.Errorf("UnknownType: %T", v)
return retNode, false
Expand Down Expand Up @@ -1292,3 +1296,86 @@ func (er *expressionRewriter) toColumn(v *ast.ColumnName) {
}
er.err = ErrUnknownColumn.GenWithStackByArgs(v.String(), clauseMsg[er.b.curClause])
}

func (er *expressionRewriter) evalDefaultExpr(v *ast.DefaultExpr) {
stkLen := len(er.ctxStack)
var colExpr *expression.Column
switch c := er.ctxStack[stkLen-1].(type) {
case *expression.Column:
colExpr = c
case *expression.CorrelatedColumn:
colExpr = &c.Column
default:
colExpr, er.err = er.schema.FindColumn(v.Name)
if er.err != nil {
er.err = errors.Trace(er.err)
return
}
if colExpr == nil {
er.err = ErrUnknownColumn.GenWithStackByArgs(v.Name.OrigColName(), "field_list")
return
}
}
dbName := colExpr.DBName
if dbName.O == "" {
// if database name is not specified, use current database name
dbName = model.NewCIStr(er.ctx.GetSessionVars().CurrentDB)
}
if colExpr.OrigTblName.O == "" {
eurekaka marked this conversation as resolved.
Show resolved Hide resolved
// column is evaluated by some expressions, for example:
// `select default(c) from (select (a+1) as c from t) as t0`
// in such case, a 'no default' error is returned
er.err = table.ErrNoDefaultValue.GenWithStackByArgs(colExpr.ColName)
return
}
var tbl table.Table
tbl, er.err = er.b.is.TableByName(dbName, colExpr.OrigTblName)
if er.err != nil {
return
}
colName := colExpr.OrigColName.O
if colName == "" {
// in some cases, OrigColName is empty, use ColName instead
colName = colExpr.ColName.O
}
col := table.FindCol(tbl.Cols(), colName)
if col == nil {
er.err = ErrUnknownColumn.GenWithStackByArgs(v.Name, "field_list")
return
}
isCurrentTimestamp := hasCurrentDatetimeDefault(col)
var val *expression.Constant
switch {
case isCurrentTimestamp && col.Tp == mysql.TypeDatetime:
// for DATETIME column with current_timestamp, use NULL to be compatible with MySQL 5.7
val = expression.Null
case isCurrentTimestamp && col.Tp == mysql.TypeTimestamp:
// for TIMESTAMP column with current_timestamp, use 0 to be compatible with MySQL 5.7
zero := types.Time{
Time: types.ZeroTime,
Type: mysql.TypeTimestamp,
Fsp: col.Decimal,
}
val = &expression.Constant{
Value: types.NewDatum(zero),
RetType: types.NewFieldType(mysql.TypeTimestamp),
}
default:
// for other columns, just use what it is
val, er.err = er.b.getDefaultValue(col)
}
if er.err != nil {
return
}
er.ctxStack = er.ctxStack[:stkLen-1]
er.ctxStack = append(er.ctxStack, val)
}

// hasCurrentDatetimeDefault checks if column has current_timestamp default value
func hasCurrentDatetimeDefault(col *table.Column) bool {
x, ok := col.DefaultValue.(string)
if !ok {
return false
}
return strings.ToLower(x) == ast.CurrentTimestamp
}
71 changes: 71 additions & 0 deletions planner/core/expression_rewriter_test.go
Expand Up @@ -17,6 +17,7 @@ import (
. "github.com/pingcap/check"
"github.com/pingcap/tidb/util/testkit"
"github.com/pingcap/tidb/util/testleak"
"github.com/pingcap/tidb/util/testutil"
)

var _ = Suite(&testExpressionRewriterSuite{})
Expand Down Expand Up @@ -58,3 +59,73 @@ func (s *testExpressionRewriterSuite) TestBinaryOpFunction(c *C) {
tk.MustQuery("SELECT * FROM t WHERE (a,b,c) <= (1,2,3) order by b").Check(testkit.Rows("1 1 <nil>", "1 2 3"))
tk.MustQuery("SELECT * FROM t WHERE (a,b,c) > (1,2,3) order by b").Check(testkit.Rows("1 3 <nil>"))
}

func (s *testExpressionRewriterSuite) TestDefaultFunction(c *C) {
defer testleak.AfterTest(c)()
store, dom, err := newStoreWithBootstrap()
c.Assert(err, IsNil)
tk := testkit.NewTestKit(c, store)
defer func() {
dom.Close()
store.Close()
}()
tk.MustExec("use test")
tk.MustExec("drop table if exists t1")
tk.MustExec(`create table t1(
a varchar(10) default 'def',
b varchar(10),
c int default '10',
d double default '3.14',
e datetime default '20180101',
f datetime default current_timestamp);`)
tk.MustExec("insert into t1(a, b, c, d) values ('1', '1', 1, 1)")
tk.MustQuery(`select
default(a) as defa,
default(b) as defb,
default(c) as defc,
default(d) as defd,
default(e) as defe,
default(f) as deff
from t1`).Check(testutil.RowsWithSep("|", "def|<nil>|10|3.14|2018-01-01 00:00:00|<nil>"))
zz-jason marked this conversation as resolved.
Show resolved Hide resolved
err = tk.ExecToErr("select default(x) from t1")
c.Assert(err.Error(), Equals, "[planner:1054]Unknown column 'x' in 'field list'")

tk.MustQuery("select default(a0) from (select a as a0 from t1) as t0").Check(testkit.Rows("def"))
eurekaka marked this conversation as resolved.
Show resolved Hide resolved
err = tk.ExecToErr("select default(a0) from (select a+1 as a0 from t1) as t0")
c.Assert(err.Error(), Equals, "[table:1364]Field 'a0' doesn't have a default value")

tk.MustExec("create table t2(a varchar(10), b varchar(10))")
tk.MustExec("insert into t2 values ('1', '1')")
err = tk.ExecToErr("select default(a) from t1, t2")
c.Assert(err.Error(), Equals, "[planner:1052]Column 'a' in field list is ambiguous")
tk.MustQuery("select default(t1.a) from t1, t2").Check(testkit.Rows("def"))

tk.MustExec(`create table t3(
a datetime default current_timestamp,
b timestamp default current_timestamp,
c timestamp(6) default current_timestamp(6),
d varchar(20) default 'current_timestamp')`)
tk.MustExec("insert into t3 values ()")
tk.MustQuery(`select
default(a) as defa,
default(b) as defb,
default(c) as defc,
default(d) as defd
from t3`).Check(testutil.RowsWithSep("|", "<nil>|0000-00-00 00:00:00|0000-00-00 00:00:00.000000|current_timestamp"))

tk.MustExec(`create table t4(a int default 1, b varchar(5))`)
tk.MustExec(`insert into t4 values (0, 'B'), (1, 'B'), (2, 'B')`)
tk.MustExec(`create table t5(d int default 0, e varchar(5))`)
tk.MustExec(`insert into t5 values (5, 'B')`)

tk.MustQuery(`select a from t4 where a > (select default(d) from t5 where t4.b = t5.e)`).Check(testkit.Rows("1", "2"))
tk.MustQuery(`select a from t4 where a > (select default(a) from t5 where t4.b = t5.e)`).Check(testkit.Rows("2"))

tk.MustExec("prepare stmt from 'select default(a) from t1';")
tk.MustQuery("execute stmt").Check(testkit.Rows("def"))
tk.MustExec("alter table t1 modify a varchar(10) default 'DEF'")
tk.MustQuery("execute stmt").Check(testkit.Rows("DEF"))

tk.MustExec("update t1 set c = c + default(c)")
tk.MustQuery("select c from t1").Check(testkit.Rows("11"))
}
2 changes: 1 addition & 1 deletion table/column.go
Expand Up @@ -390,7 +390,7 @@ func getColDefaultValueFromNil(ctx sessionctx.Context, col *model.ColumnInfo) (t
sc.AppendWarning(ErrColumnCantNull.GenWithStackByArgs(col.Name))
return GetZeroValue(col), nil
}
return types.Datum{}, ErrNoDefaultValue.GenWithStack("Field '%s' doesn't have a default value", col.Name)
return types.Datum{}, ErrNoDefaultValue.GenWithStackByArgs(col.Name)
}

// GetZeroValue gets zero value for given column type.
Expand Down
2 changes: 1 addition & 1 deletion table/table.go
Expand Up @@ -58,7 +58,7 @@ var (

// ErrNoDefaultValue is used when insert a row, the column value is not given, and the column has not null flag
// and it doesn't have a default value.
ErrNoDefaultValue = terror.ClassTable.New(codeNoDefaultValue, "field doesn't have a default value")
ErrNoDefaultValue = terror.ClassTable.New(codeNoDefaultValue, mysql.MySQLErrName[mysql.ErrNoDefaultForField])
// ErrIndexOutBound returns for index column offset out of bound.
ErrIndexOutBound = terror.ClassTable.New(codeIndexOutBound, "index column offset out of bound")
// ErrUnsupportedOp returns for unsupported operation.
Expand Down