New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
*: make insert with calculated value behave the same as MySQL. #4603
Changes from 13 commits
fbf1a26
d0d6763
d384ee0
78e92d5
e7d14f1
13e93b9
12d5964
8ce624e
7798788
04ecc89
6a18c49
498734a
806c09e
fbed899
d23f524
2467c7d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -643,7 +643,10 @@ type InsertValues struct { | |
batchRows int64 | ||
lastInsertID uint64 | ||
ctx context.Context | ||
SelectExec Executor | ||
|
||
needFillDefaultValues bool | ||
|
||
SelectExec Executor | ||
|
||
Table table.Table | ||
Columns []*ast.ColumnName | ||
|
@@ -885,16 +888,61 @@ func (e *InsertValues) getRows(cols []*table.Column, ignoreErr bool) (rows [][]t | |
return | ||
} | ||
|
||
// getRow eval the insert statement. Because the value of column may calculated based on other column, | ||
// it use fillDefaultValues to init the empty row before eval expressions when needFillDefaultValues is true. | ||
func (e *InsertValues) getRow(cols []*table.Column, list []expression.Expression, ignoreErr bool) ([]types.Datum, error) { | ||
vals := make([]types.Datum, len(list)) | ||
row := make([]types.Datum, len(e.Table.Cols())) | ||
hasValue := make([]bool, len(e.Table.Cols())) | ||
|
||
if e.needFillDefaultValues { | ||
if err := e.fillDefaultValues(row, hasValue, ignoreErr); err != nil { | ||
return nil, errors.Trace(err) | ||
} | ||
} | ||
|
||
for i, expr := range list { | ||
val, err := expr.Eval(nil) | ||
vals[i] = val | ||
if err != nil { | ||
val, err := expr.Eval(row) | ||
if err = e.filterErr(err, ignoreErr); err != nil { | ||
return nil, errors.Trace(err) | ||
} | ||
val, err = table.CastValue(e.ctx, val, cols[i].ToInfo()) | ||
if err = e.filterErr(err, ignoreErr); err != nil { | ||
return nil, errors.Trace(err) | ||
} | ||
|
||
offset := cols[i].Offset | ||
row[offset], hasValue[offset] = val, true | ||
} | ||
|
||
return e.fillGenColData(cols, len(list), hasValue, row, ignoreErr) | ||
} | ||
|
||
// fillDefaultValues fills a row followed by these rules: | ||
// 1. for nullable and no default value column, use NULL. | ||
// 2. for nullable and have default value column, use it's default value. | ||
// 3. for not null column, use zero value even in strict mode. | ||
// 4. for auto_increment column, use zero value. | ||
// 5. for generated column, use NULL. | ||
func (e *InsertValues) fillDefaultValues(row []types.Datum, hasValue []bool, ignoreErr bool) error { | ||
for i, c := range e.Table.Cols() { | ||
var err error | ||
if c.IsGenerated() { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. add a comment for these checks. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Will these check be influenced by sql_mode? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It use zero value for |
||
continue | ||
} else if mysql.HasAutoIncrementFlag(c.Flag) { | ||
row[i] = table.GetZeroValue(c.ToInfo()) | ||
} else { | ||
row[i], err = table.GetColDefaultValue(e.ctx, c.ToInfo()) | ||
hasValue[c.Offset] = true | ||
if table.ErrNoDefaultValue.Equal(err) { | ||
row[i] = table.GetZeroValue(c.ToInfo()) | ||
hasValue[c.Offset] = false | ||
} else if err = e.filterErr(err, ignoreErr); err != nil { | ||
return errors.Trace(err) | ||
} | ||
} | ||
} | ||
return e.fillRowData(cols, vals, ignoreErr) | ||
|
||
return nil | ||
} | ||
|
||
func (e *InsertValues) getRowsSelect(cols []*table.Column, ignoreErr bool) ([][]types.Datum, error) { | ||
|
@@ -929,6 +977,11 @@ func (e *InsertValues) fillRowData(cols []*table.Column, vals []types.Datum, ign | |
row[offset] = v | ||
hasValue[offset] = true | ||
} | ||
|
||
return e.fillGenColData(cols, len(vals), hasValue, row, ignoreErr) | ||
} | ||
|
||
func (e *InsertValues) fillGenColData(cols []*table.Column, valLen int, hasValue []bool, row []types.Datum, ignoreErr bool) ([]types.Datum, error) { | ||
err := e.initDefaultValues(row, hasValue, ignoreErr) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why do initDefaultValues again? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can we do these check in fillDefaultValues? |
||
if err != nil { | ||
return nil, errors.Trace(err) | ||
|
@@ -939,7 +992,7 @@ func (e *InsertValues) fillRowData(cols []*table.Column, vals []types.Datum, ign | |
if err = e.filterErr(err, ignoreErr); err != nil { | ||
return nil, errors.Trace(err) | ||
} | ||
offset := cols[len(vals)+i].Offset | ||
offset := cols[valLen+i].Offset | ||
row[offset] = val | ||
} | ||
if err = table.CastValues(e.ctx, row, cols, ignoreErr); err != nil { | ||
|
@@ -964,6 +1017,8 @@ func (e *InsertValues) filterErr(err error, ignoreErr bool) error { | |
return nil | ||
} | ||
|
||
// initDefaultValues fills generated columns, auto_increment column and empty column. | ||
// For NOT NULL column, it will return error or use zero value based on sql_mode. | ||
func (e *InsertValues) initDefaultValues(row []types.Datum, hasValue []bool, ignoreErr bool) error { | ||
var defaultValueCols []*table.Column | ||
strictSQL := e.ctx.GetSessionVars().StrictSQLMode | ||
|
@@ -980,6 +1035,9 @@ func (e *InsertValues) initDefaultValues(row []types.Datum, hasValue []bool, ign | |
// Just leave generated column as null. It will be calculated later | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we add a comment for initDefaultValues |
||
// but before we check whether the column can be null or not. | ||
needDefaultValue = false | ||
if !hasValue[i] { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. any test case for this check? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Test L1279-L1281: create table t(a int auto_increment key, b int);
set SQL_MODE=NO_AUTO_VALUE_ON_ZERO;
insert into t (b) value (a+1); If we don't reset |
||
row[i].SetNull() | ||
} | ||
} | ||
if needDefaultValue { | ||
var err error | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1171,3 +1171,119 @@ func (s *testSuite) TestIssue4067(c *C) { | |
tk.MustExec("delete from t1 where id in (select id from t2)") | ||
tk.MustQuery("select * from t1").Check(nil) | ||
} | ||
|
||
func (s *testSuite) TestInsertCalculatedValue(c *C) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. add test cases for more generated columns. |
||
tk := testkit.NewTestKit(c, s.store) | ||
tk.MustExec("use test") | ||
|
||
tk.MustExec("drop table if exists t") | ||
tk.MustExec("create table t(a int, b int)") | ||
tk.MustExec("insert into t set a=1, b=a+1") | ||
tk.MustQuery("select a, b from t").Check(testkit.Rows("1 2")) | ||
|
||
tk.MustExec("drop table if exists t") | ||
tk.MustExec("create table t(a int default 100, b int)") | ||
tk.MustExec("insert into t set b=a+1, a=1") | ||
tk.MustQuery("select a, b from t").Check(testkit.Rows("1 101")) | ||
tk.MustExec("insert into t (b) value (a)") | ||
tk.MustQuery("select * from t where b = 100").Check(testkit.Rows("100 100")) | ||
tk.MustExec("insert into t set a=2, b=a+1") | ||
tk.MustQuery("select * from t where a = 2").Check(testkit.Rows("2 3")) | ||
|
||
tk.MustExec("drop table if exists t") | ||
tk.MustExec("create table t (c int)") | ||
tk.MustExec("insert into test.t set test.t.c = '1'") | ||
tk.MustQuery("select * from t").Check(testkit.Rows("1")) | ||
|
||
tk.MustExec("drop table if exists t") | ||
tk.MustExec("create table t(a int default 1)") | ||
tk.MustExec("insert into t values (a)") | ||
tk.MustQuery("select * from t").Check(testkit.Rows("1")) | ||
|
||
tk.MustExec("drop table if exists t") | ||
tk.MustExec("create table t (a int, b int, c int, d int)") | ||
tk.MustExec("insert into t value (1, 2, a+1, b+1)") | ||
tk.MustQuery("select * from t").Check(testkit.Rows("1 2 2 3")) | ||
|
||
tk.MustExec("drop table if exists t") | ||
tk.MustExec("create table t (a int not null)") | ||
tk.MustExec("insert into t values (a+2)") | ||
tk.MustExec("insert into t values (a)") | ||
tk.MustQuery("select * from t order by a").Check(testkit.Rows("0", "2")) | ||
|
||
tk.MustExec("drop table if exists t") | ||
tk.MustExec("create table t (a bigint not null, b bigint not null)") | ||
tk.MustExec("insert into t value(b + 1, a)") | ||
tk.MustExec("insert into t set a = b + a, b = a + 1") | ||
tk.MustExec("insert into t value(1000, a)") | ||
tk.MustExec("insert t set b = sqrt(a + 4), a = 10") | ||
tk.MustQuery("select * from t order by a").Check(testkit.Rows("0 1", "1 1", "10 2", "1000 1000")) | ||
|
||
tk.MustExec("drop table if exists t") | ||
tk.MustExec("create table t(a int)") | ||
tk.MustExec("insert into t values(a)") | ||
tk.MustQuery("select * from t").Check(testkit.Rows("<nil>")) | ||
|
||
tk.MustExec("drop table if exists t") | ||
tk.MustExec("create table t(a enum('a', 'b'))") | ||
tk.MustExec("insert into t values(a)") | ||
tk.MustQuery("select * from t").Check(testkit.Rows("<nil>")) | ||
tk.MustExec("drop table if exists t") | ||
tk.MustExec("create table t(a enum('a', 'b') default 'a')") | ||
tk.MustExec("insert into t values(a)") | ||
tk.MustExec("insert into t values(a+1)") | ||
tk.MustQuery("select * from t order by a").Check(testkit.Rows("a", "b")) | ||
|
||
tk.MustExec("drop table if exists t") | ||
tk.MustExec("create table t(a blob)") | ||
tk.MustExec("insert into t values(a)") | ||
tk.MustQuery("select * from t").Check(testkit.Rows("<nil>")) | ||
|
||
tk.MustExec("drop table if exists t") | ||
tk.MustExec("create table t(a varchar(20) default 'a')") | ||
tk.MustExec("insert into t values(a)") | ||
tk.MustExec("insert into t values(upper(a))") | ||
tk.MustQuery("select * from t order by a").Check(testkit.Rows("A", "a")) | ||
tk.MustExec("drop table if exists t") | ||
tk.MustExec("create table t(a varchar(20) not null, b varchar(20))") | ||
tk.MustExec("insert into t value (a, b)") | ||
tk.MustQuery("select * from t").Check(testkit.Rows(" <nil>")) | ||
|
||
tk.MustExec("drop table if exists t") | ||
tk.MustExec("create table t(a int, b int)") | ||
tk.MustExec("insert into t values(a*b, b*b)") | ||
tk.MustQuery("select * from t").Check(testkit.Rows("<nil> <nil>")) | ||
|
||
tk.MustExec("drop table if exists t") | ||
tk.MustExec("create table t (a json not null, b int)") | ||
tk.MustExec("insert into t value (a,a->'$')") | ||
tk.MustQuery("select * from t").Check(testkit.Rows("null 0")) | ||
|
||
tk.MustExec("drop table if exists t") | ||
tk.MustExec("create table t(a json, b int, c int as (a->'$.a'))") | ||
tk.MustExec("insert into t (a, b) value (a, a->'$.a'+1)") | ||
tk.MustExec("insert into t (b) value (a->'$.a'+1)") | ||
tk.MustQuery("select * from t").Check(testkit.Rows("<nil> <nil> <nil>", "<nil> <nil> <nil>")) | ||
tk.MustExec(`insert into t (a, b) value ('{"a": 1}', a->'$.a'+1)`) | ||
tk.MustQuery("select * from t where c = 1").Check(testkit.Rows(`{"a":1} 2 1`)) | ||
tk.MustExec("truncate table t") | ||
tk.MustExec("insert t set b = c + 1") | ||
tk.MustQuery("select * from t").Check(testkit.Rows("<nil> <nil> <nil>")) | ||
tk.MustExec("truncate table t") | ||
tk.MustExec(`insert t set a = '{"a": 1}', b = c`) | ||
tk.MustQuery("select * from t").Check(testkit.Rows(`{"a":1} <nil> 1`)) | ||
|
||
tk.MustExec("drop table if exists t") | ||
tk.MustExec("create table t(a int auto_increment key, b int)") | ||
tk.MustExec("insert into t (b) value (a)") | ||
tk.MustExec("insert into t value (a, a+1)") | ||
tk.MustExec("set SQL_MODE=NO_AUTO_VALUE_ON_ZERO") | ||
tk.MustExec("insert into t (b) value (a+1)") | ||
tk.MustQuery("select * from t order by a").Check(testkit.Rows("1 0", "2 1", "3 1")) | ||
|
||
tk.MustExec("set SQL_MODE=STRICT_ALL_TABLES") | ||
tk.MustExec("drop table if exists t") | ||
tk.MustExec("create table t(a int not null, b int, c int as (sqrt(a)))") | ||
tk.MustExec("insert t set b = a, a = 4") | ||
tk.MustQuery("select * from t").Check(testkit.Rows("4 0 2")) | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -693,7 +693,8 @@ func (b *planBuilder) buildInsert(insert *ast.InsertStmt) Plan { | |
return nil | ||
} | ||
tableInfo := tn.TableInfo | ||
schema := expression.TableInfo2Schema(tableInfo) | ||
// Build Schema with DBName otherwise ColumnRef with DBName cannot match any Column in Schema. | ||
schema := expression.TableInfo2SchemaWithDBName(tn.Schema, tableInfo) | ||
tableInPlan, ok := b.is.TableByID(tableInfo.ID) | ||
if !ok { | ||
b.err = errors.Errorf("Can't get table %s.", tableInfo.Name.O) | ||
|
@@ -733,6 +734,20 @@ func (b *planBuilder) buildInsert(insert *ast.InsertStmt) Plan { | |
} | ||
} | ||
|
||
mockTablePlan := TableDual{}.init(b.allocator, b.ctx) | ||
mockTablePlan.SetSchema(schema) | ||
|
||
checkRefColumn := func(n ast.Node) ast.Node { | ||
if insertPlan.NeedFillDefaultValue { | ||
return n | ||
} | ||
switch n.(type) { | ||
case *ast.ColumnName, *ast.ColumnNameExpr: | ||
insertPlan.NeedFillDefaultValue = true | ||
} | ||
return n | ||
} | ||
|
||
cols := insertPlan.Table.Cols() | ||
maxValuesItemLength := 0 // the max length of items in VALUES list. | ||
for _, valuesItem := range insert.Lists { | ||
|
@@ -752,7 +767,7 @@ func (b *planBuilder) buildInsert(insert *ast.InsertStmt) Plan { | |
RetType: &val.Type, | ||
} | ||
} else { | ||
expr, _, err = b.rewrite(valueItem, nil, nil, true) | ||
expr, _, err = b.rewriteWithPreprocess(valueItem, mockTablePlan, nil, true, checkRefColumn) | ||
} | ||
if err != nil { | ||
b.err = errors.Trace(err) | ||
|
@@ -784,8 +799,6 @@ func (b *planBuilder) buildInsert(insert *ast.InsertStmt) Plan { | |
} | ||
} | ||
|
||
mockTablePlan := TableDual{}.init(b.allocator, b.ctx) | ||
mockTablePlan.SetSchema(schema) | ||
for _, assign := range insert.Setlist { | ||
col, err := schema.FindColumn(assign.Column) | ||
if err != nil { | ||
|
@@ -801,9 +814,7 @@ func (b *planBuilder) buildInsert(insert *ast.InsertStmt) Plan { | |
b.err = ErrBadGeneratedColumn.GenByArgs(assign.Column.Name.O, tableInfo.Name.O) | ||
return nil | ||
} | ||
// Here we keep different behaviours with MySQL. MySQL allow set a = b, b = a and the result is NULL, NULL. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. keep the comment There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. But after merge this PR, result of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ok. got it. |
||
// It's unreasonable. | ||
expr, _, err := b.rewrite(assign.Expr, mockTablePlan, nil, true) | ||
expr, _, err := b.rewriteWithPreprocess(assign.Expr, mockTablePlan, nil, true, checkRefColumn) | ||
if err != nil { | ||
b.err = errors.Trace(err) | ||
return nil | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove the blank line.