Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

*: make insert with calculated value behave the same as MySQL. #4603

Merged
merged 16 commits into from Sep 27, 2017
12 changes: 12 additions & 0 deletions ast/expressions.go
Expand Up @@ -345,6 +345,18 @@ func (n *ColumnName) Accept(v Visitor) (Node, bool) {
return v.Leave(n)
}

// String implements Stringer interface.
func (n *ColumnName) String() string {
result := n.Name.L
if n.Table.L != "" {
result = n.Table.L + "." + result
}
if n.Schema.L != "" {
result = n.Schema.L + "." + result
}
return result
}

// ColumnNameExpr represents a column name expression.
type ColumnNameExpr struct {
exprNode
Expand Down
13 changes: 7 additions & 6 deletions executor/builder.go
Expand Up @@ -292,12 +292,13 @@ func (b *executorBuilder) buildSet(v *plan.Set) Executor {

func (b *executorBuilder) buildInsert(v *plan.Insert) Executor {
ivs := &InsertValues{
ctx: b.ctx,
Columns: v.Columns,
Lists: v.Lists,
Setlist: v.Setlist,
GenColumns: v.GenCols.Columns,
GenExprs: v.GenCols.Exprs,
ctx: b.ctx,
Columns: v.Columns,
Lists: v.Lists,
Setlist: v.Setlist,
GenColumns: v.GenCols.Columns,
GenExprs: v.GenCols.Exprs,
needFillDefaultValues: v.NeedFillDefaultValue,
}
if len(v.Children()) > 0 {
ivs.SelectExec = b.build(v.Children()[0])
Expand Down
72 changes: 65 additions & 7 deletions executor/write.go
Expand Up @@ -643,7 +643,10 @@ type InsertValues struct {
batchRows int64
lastInsertID uint64
ctx context.Context
SelectExec Executor

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove the blank line.

needFillDefaultValues bool

SelectExec Executor

Table table.Table
Columns []*ast.ColumnName
Expand Down Expand Up @@ -885,16 +888,61 @@ func (e *InsertValues) getRows(cols []*table.Column, ignoreErr bool) (rows [][]t
return
}

// getRow eval the insert statement. Because the value of column may calculated based on other column,
// it use fillDefaultValues to init the empty row before eval expressions when needFillDefaultValues is true.
func (e *InsertValues) getRow(cols []*table.Column, list []expression.Expression, ignoreErr bool) ([]types.Datum, error) {
vals := make([]types.Datum, len(list))
row := make([]types.Datum, len(e.Table.Cols()))
hasValue := make([]bool, len(e.Table.Cols()))

if e.needFillDefaultValues {
if err := e.fillDefaultValues(row, hasValue, ignoreErr); err != nil {
return nil, errors.Trace(err)
}
}

for i, expr := range list {
val, err := expr.Eval(nil)
vals[i] = val
if err != nil {
val, err := expr.Eval(row)
if err = e.filterErr(err, ignoreErr); err != nil {
return nil, errors.Trace(err)
}
val, err = table.CastValue(e.ctx, val, cols[i].ToInfo())
if err = e.filterErr(err, ignoreErr); err != nil {
return nil, errors.Trace(err)
}

offset := cols[i].Offset
row[offset], hasValue[offset] = val, true
}

return e.fillGenColData(cols, len(list), hasValue, row, ignoreErr)
}

// fillDefaultValues fills a row followed by these rules:
// 1. for nullable and no default value column, use NULL.
// 2. for nullable and have default value column, use it's default value.
// 3. for not null column, use zero value even in strict mode.
// 4. for auto_increment column, use zero value.
// 5. for generated column, use NULL.
func (e *InsertValues) fillDefaultValues(row []types.Datum, hasValue []bool, ignoreErr bool) error {
for i, c := range e.Table.Cols() {
var err error
if c.IsGenerated() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add a comment for these checks.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will these check be influenced by sql_mode?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It use zero value for not null column even in strict mode. The column filled by fillDefaultValues will not mark in hasValue, so checkRowData will check nullability and insert correct value for column such as auto_increment based on sql_mode.

continue
} else if mysql.HasAutoIncrementFlag(c.Flag) {
row[i] = table.GetZeroValue(c.ToInfo())
} else {
row[i], err = table.GetColDefaultValue(e.ctx, c.ToInfo())
hasValue[c.Offset] = true
if table.ErrNoDefaultValue.Equal(err) {
row[i] = table.GetZeroValue(c.ToInfo())
hasValue[c.Offset] = false
} else if err = e.filterErr(err, ignoreErr); err != nil {
return errors.Trace(err)
}
}
}
return e.fillRowData(cols, vals, ignoreErr)

return nil
}

func (e *InsertValues) getRowsSelect(cols []*table.Column, ignoreErr bool) ([][]types.Datum, error) {
Expand Down Expand Up @@ -929,6 +977,11 @@ func (e *InsertValues) fillRowData(cols []*table.Column, vals []types.Datum, ign
row[offset] = v
hasValue[offset] = true
}

return e.fillGenColData(cols, len(vals), hasValue, row, ignoreErr)
}

func (e *InsertValues) fillGenColData(cols []*table.Column, valLen int, hasValue []bool, row []types.Datum, ignoreErr bool) ([]types.Datum, error) {
err := e.initDefaultValues(row, hasValue, ignoreErr)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do initDefaultValues again?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fillDefaultValues just ignore sql_mode and use zero value or null for columns, the result row just a eval context so set a = 1, b = a can eval correctly. But the actual value for no value column is influenced by sql_mode and other conditions, so I simply ask initDefaultValues to fill the correct value for no explicit value columns.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we do these check in fillDefaultValues?

if err != nil {
return nil, errors.Trace(err)
Expand All @@ -939,7 +992,7 @@ func (e *InsertValues) fillRowData(cols []*table.Column, vals []types.Datum, ign
if err = e.filterErr(err, ignoreErr); err != nil {
return nil, errors.Trace(err)
}
offset := cols[len(vals)+i].Offset
offset := cols[valLen+i].Offset
row[offset] = val
}
if err = table.CastValues(e.ctx, row, cols, ignoreErr); err != nil {
Expand All @@ -964,6 +1017,8 @@ func (e *InsertValues) filterErr(err error, ignoreErr bool) error {
return nil
}

// initDefaultValues fills generated columns, auto_increment column and empty column.
// For NOT NULL column, it will return error or use zero value based on sql_mode.
func (e *InsertValues) initDefaultValues(row []types.Datum, hasValue []bool, ignoreErr bool) error {
var defaultValueCols []*table.Column
strictSQL := e.ctx.GetSessionVars().StrictSQLMode
Expand All @@ -980,6 +1035,9 @@ func (e *InsertValues) initDefaultValues(row []types.Datum, hasValue []bool, ign
// Just leave generated column as null. It will be calculated later
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we add a comment for initDefaultValues
to distinguish it with fillDefaultValues easily.

// but before we check whether the column can be null or not.
needDefaultValue = false
if !hasValue[i] {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

any test case for this check?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test L1279-L1281:

create table t(a int auto_increment key, b int);
set SQL_MODE=NO_AUTO_VALUE_ON_ZERO;
insert into t (b) value (a+1);

If we don't reset a to NULL, a will be 0.

row[i].SetNull()
}
}
if needDefaultValue {
var err error
Expand Down
116 changes: 116 additions & 0 deletions executor/write_test.go
Expand Up @@ -1171,3 +1171,119 @@ func (s *testSuite) TestIssue4067(c *C) {
tk.MustExec("delete from t1 where id in (select id from t2)")
tk.MustQuery("select * from t1").Check(nil)
}

func (s *testSuite) TestInsertCalculatedValue(c *C) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add test cases for more generated columns.

tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test")

tk.MustExec("drop table if exists t")
tk.MustExec("create table t(a int, b int)")
tk.MustExec("insert into t set a=1, b=a+1")
tk.MustQuery("select a, b from t").Check(testkit.Rows("1 2"))

tk.MustExec("drop table if exists t")
tk.MustExec("create table t(a int default 100, b int)")
tk.MustExec("insert into t set b=a+1, a=1")
tk.MustQuery("select a, b from t").Check(testkit.Rows("1 101"))
tk.MustExec("insert into t (b) value (a)")
tk.MustQuery("select * from t where b = 100").Check(testkit.Rows("100 100"))
tk.MustExec("insert into t set a=2, b=a+1")
tk.MustQuery("select * from t where a = 2").Check(testkit.Rows("2 3"))

tk.MustExec("drop table if exists t")
tk.MustExec("create table t (c int)")
tk.MustExec("insert into test.t set test.t.c = '1'")
tk.MustQuery("select * from t").Check(testkit.Rows("1"))

tk.MustExec("drop table if exists t")
tk.MustExec("create table t(a int default 1)")
tk.MustExec("insert into t values (a)")
tk.MustQuery("select * from t").Check(testkit.Rows("1"))

tk.MustExec("drop table if exists t")
tk.MustExec("create table t (a int, b int, c int, d int)")
tk.MustExec("insert into t value (1, 2, a+1, b+1)")
tk.MustQuery("select * from t").Check(testkit.Rows("1 2 2 3"))

tk.MustExec("drop table if exists t")
tk.MustExec("create table t (a int not null)")
tk.MustExec("insert into t values (a+2)")
tk.MustExec("insert into t values (a)")
tk.MustQuery("select * from t order by a").Check(testkit.Rows("0", "2"))

tk.MustExec("drop table if exists t")
tk.MustExec("create table t (a bigint not null, b bigint not null)")
tk.MustExec("insert into t value(b + 1, a)")
tk.MustExec("insert into t set a = b + a, b = a + 1")
tk.MustExec("insert into t value(1000, a)")
tk.MustExec("insert t set b = sqrt(a + 4), a = 10")
tk.MustQuery("select * from t order by a").Check(testkit.Rows("0 1", "1 1", "10 2", "1000 1000"))

tk.MustExec("drop table if exists t")
tk.MustExec("create table t(a int)")
tk.MustExec("insert into t values(a)")
tk.MustQuery("select * from t").Check(testkit.Rows("<nil>"))

tk.MustExec("drop table if exists t")
tk.MustExec("create table t(a enum('a', 'b'))")
tk.MustExec("insert into t values(a)")
tk.MustQuery("select * from t").Check(testkit.Rows("<nil>"))
tk.MustExec("drop table if exists t")
tk.MustExec("create table t(a enum('a', 'b') default 'a')")
tk.MustExec("insert into t values(a)")
tk.MustExec("insert into t values(a+1)")
tk.MustQuery("select * from t order by a").Check(testkit.Rows("a", "b"))

tk.MustExec("drop table if exists t")
tk.MustExec("create table t(a blob)")
tk.MustExec("insert into t values(a)")
tk.MustQuery("select * from t").Check(testkit.Rows("<nil>"))

tk.MustExec("drop table if exists t")
tk.MustExec("create table t(a varchar(20) default 'a')")
tk.MustExec("insert into t values(a)")
tk.MustExec("insert into t values(upper(a))")
tk.MustQuery("select * from t order by a").Check(testkit.Rows("A", "a"))
tk.MustExec("drop table if exists t")
tk.MustExec("create table t(a varchar(20) not null, b varchar(20))")
tk.MustExec("insert into t value (a, b)")
tk.MustQuery("select * from t").Check(testkit.Rows(" <nil>"))

tk.MustExec("drop table if exists t")
tk.MustExec("create table t(a int, b int)")
tk.MustExec("insert into t values(a*b, b*b)")
tk.MustQuery("select * from t").Check(testkit.Rows("<nil> <nil>"))

tk.MustExec("drop table if exists t")
tk.MustExec("create table t (a json not null, b int)")
tk.MustExec("insert into t value (a,a->'$')")
tk.MustQuery("select * from t").Check(testkit.Rows("null 0"))

tk.MustExec("drop table if exists t")
tk.MustExec("create table t(a json, b int, c int as (a->'$.a'))")
tk.MustExec("insert into t (a, b) value (a, a->'$.a'+1)")
tk.MustExec("insert into t (b) value (a->'$.a'+1)")
tk.MustQuery("select * from t").Check(testkit.Rows("<nil> <nil> <nil>", "<nil> <nil> <nil>"))
tk.MustExec(`insert into t (a, b) value ('{"a": 1}', a->'$.a'+1)`)
tk.MustQuery("select * from t where c = 1").Check(testkit.Rows(`{"a":1} 2 1`))
tk.MustExec("truncate table t")
tk.MustExec("insert t set b = c + 1")
tk.MustQuery("select * from t").Check(testkit.Rows("<nil> <nil> <nil>"))
tk.MustExec("truncate table t")
tk.MustExec(`insert t set a = '{"a": 1}', b = c`)
tk.MustQuery("select * from t").Check(testkit.Rows(`{"a":1} <nil> 1`))

tk.MustExec("drop table if exists t")
tk.MustExec("create table t(a int auto_increment key, b int)")
tk.MustExec("insert into t (b) value (a)")
tk.MustExec("insert into t value (a, a+1)")
tk.MustExec("set SQL_MODE=NO_AUTO_VALUE_ON_ZERO")
tk.MustExec("insert into t (b) value (a+1)")
tk.MustQuery("select * from t order by a").Check(testkit.Rows("1 0", "2 1", "3 1"))

tk.MustExec("set SQL_MODE=STRICT_ALL_TABLES")
tk.MustExec("drop table if exists t")
tk.MustExec("create table t(a int not null, b int, c int as (sqrt(a)))")
tk.MustExec("insert t set b = a, a = 4")
tk.MustQuery("select * from t").Check(testkit.Rows("4 0 2"))
}
18 changes: 15 additions & 3 deletions expression/expression.go
Expand Up @@ -335,9 +335,14 @@ func EvaluateExprWithNull(ctx context.Context, schema *Schema, expr Expression)
}
}

// TableInfo2Schema converts table info to schema.
// TableInfo2Schema converts table info to schema with empty DBName.
func TableInfo2Schema(tbl *model.TableInfo) *Schema {
cols := ColumnInfos2Columns(tbl.Name, tbl.Columns)
return TableInfo2SchemaWithDBName(model.CIStr{}, tbl)
}

// TableInfo2SchemaWithDBName converts table info to schema.
func TableInfo2SchemaWithDBName(dbName model.CIStr, tbl *model.TableInfo) *Schema {
cols := ColumnInfos2ColumnsWithDBName(dbName, tbl.Name, tbl.Columns)
keys := make([]KeyInfo, 0, len(tbl.Indices)+1)
for _, idx := range tbl.Indices {
if !idx.Unique || idx.State != model.StatePublic {
Expand Down Expand Up @@ -379,15 +384,22 @@ func TableInfo2Schema(tbl *model.TableInfo) *Schema {
return schema
}

// ColumnInfos2Columns converts a slice of ColumnInfo to a slice of Column.
// ColumnInfos2Columns converts a slice of ColumnInfo to a slice of Column with empty DBName.
func ColumnInfos2Columns(tblName model.CIStr, colInfos []*model.ColumnInfo) []*Column {
return ColumnInfos2ColumnsWithDBName(model.CIStr{}, tblName, colInfos)
}

// ColumnInfos2ColumnsWithDBName converts a slice of ColumnInfo to a slice of Column.
func ColumnInfos2ColumnsWithDBName(dbName, tblName model.CIStr, colInfos []*model.ColumnInfo) []*Column {
columns := make([]*Column, 0, len(colInfos))
for i, col := range colInfos {
newCol := &Column{
ColName: col.Name,
TblName: tblName,
DBName: dbName,
RetType: &col.FieldType,
Position: i,
Index: col.Offset,
}
columns = append(columns, newCol)
}
Expand Down
2 changes: 1 addition & 1 deletion plan/expression_rewriter.go
Expand Up @@ -1150,5 +1150,5 @@ func (er *expressionRewriter) toColumn(v *ast.ColumnName) {
return
}
}
er.err = ErrUnknownColumn.GenByArgs(v.Text(), "field list")
er.err = ErrUnknownColumn.GenByArgs(v.String(), "field list")
}
25 changes: 18 additions & 7 deletions plan/planbuilder.go
Expand Up @@ -693,7 +693,8 @@ func (b *planBuilder) buildInsert(insert *ast.InsertStmt) Plan {
return nil
}
tableInfo := tn.TableInfo
schema := expression.TableInfo2Schema(tableInfo)
// Build Schema with DBName otherwise ColumnRef with DBName cannot match any Column in Schema.
schema := expression.TableInfo2SchemaWithDBName(tn.Schema, tableInfo)
tableInPlan, ok := b.is.TableByID(tableInfo.ID)
if !ok {
b.err = errors.Errorf("Can't get table %s.", tableInfo.Name.O)
Expand Down Expand Up @@ -733,6 +734,20 @@ func (b *planBuilder) buildInsert(insert *ast.InsertStmt) Plan {
}
}

mockTablePlan := TableDual{}.init(b.allocator, b.ctx)
mockTablePlan.SetSchema(schema)

checkRefColumn := func(n ast.Node) ast.Node {
if insertPlan.NeedFillDefaultValue {
return n
}
switch n.(type) {
case *ast.ColumnName, *ast.ColumnNameExpr:
insertPlan.NeedFillDefaultValue = true
}
return n
}

cols := insertPlan.Table.Cols()
maxValuesItemLength := 0 // the max length of items in VALUES list.
for _, valuesItem := range insert.Lists {
Expand All @@ -752,7 +767,7 @@ func (b *planBuilder) buildInsert(insert *ast.InsertStmt) Plan {
RetType: &val.Type,
}
} else {
expr, _, err = b.rewrite(valueItem, nil, nil, true)
expr, _, err = b.rewriteWithPreprocess(valueItem, mockTablePlan, nil, true, checkRefColumn)
}
if err != nil {
b.err = errors.Trace(err)
Expand Down Expand Up @@ -784,8 +799,6 @@ func (b *planBuilder) buildInsert(insert *ast.InsertStmt) Plan {
}
}

mockTablePlan := TableDual{}.init(b.allocator, b.ctx)
mockTablePlan.SetSchema(schema)
for _, assign := range insert.Setlist {
col, err := schema.FindColumn(assign.Column)
if err != nil {
Expand All @@ -801,9 +814,7 @@ func (b *planBuilder) buildInsert(insert *ast.InsertStmt) Plan {
b.err = ErrBadGeneratedColumn.GenByArgs(assign.Column.Name.O, tableInfo.Name.O)
return nil
}
// Here we keep different behaviours with MySQL. MySQL allow set a = b, b = a and the result is NULL, NULL.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

keep the comment

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But after merge this PR, result of set a = b, b = a is NULL, NULL.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok. got it.

// It's unreasonable.
expr, _, err := b.rewrite(assign.Expr, mockTablePlan, nil, true)
expr, _, err := b.rewriteWithPreprocess(assign.Expr, mockTablePlan, nil, true, checkRefColumn)
if err != nil {
b.err = errors.Trace(err)
return nil
Expand Down
3 changes: 3 additions & 0 deletions plan/plans.go
Expand Up @@ -136,6 +136,9 @@ type Insert struct {
Priority mysql.PriorityEnum
IgnoreErr bool

// NeedFillDefaultValue is true when expr in value list reference other column.
NeedFillDefaultValue bool

GenCols InsertGeneratedColumns
}

Expand Down