Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

*: make insert with calculated value behave the same as MySQL. #4603

Merged
merged 16 commits into from
Sep 27, 2017
12 changes: 12 additions & 0 deletions ast/expressions.go
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,18 @@ func (n *ColumnName) Accept(v Visitor) (Node, bool) {
return v.Leave(n)
}

// String implements Stringer interface.
func (n *ColumnName) String() string {
result := n.Name.L
if n.Table.L != "" {
result = n.Table.L + "." + result
}
if n.Schema.L != "" {
result = n.Schema.L + "." + result
}
return result
}

// ColumnNameExpr represents a column name expression.
type ColumnNameExpr struct {
exprNode
Expand Down
59 changes: 54 additions & 5 deletions executor/write.go
Original file line number Diff line number Diff line change
Expand Up @@ -886,15 +886,56 @@ func (e *InsertValues) getRows(cols []*table.Column, ignoreErr bool) (rows [][]t
}

func (e *InsertValues) getRow(cols []*table.Column, list []expression.Expression, ignoreErr bool) ([]types.Datum, error) {
vals := make([]types.Datum, len(list))
row := make([]types.Datum, len(e.Table.Cols()))
hasValue := make([]bool, len(e.Table.Cols()))
if err := e.fillDefaultValues(row, ignoreErr); err != nil {
return nil, errors.Trace(err)
}

for i, expr := range list {
val, err := expr.Eval(nil)
vals[i] = val
val, err := expr.Eval(row)
if err != nil {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use e.filterErr(err) here.

return nil, errors.Trace(err)
}
val, err = table.CastValue(e.ctx, val, cols[i].ToInfo())
if err != nil {
return nil, errors.Trace(err)
}

offset := cols[i].Offset
row[offset] = val
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

row[offset], hasValue[offset] = val, true

hasValue[offset] = true
if err != nil {
return nil, errors.Trace(err)
}
}
return e.fillRowData(cols, vals, ignoreErr)

return e.checkRowData(cols, len(list), hasValue, row, ignoreErr)
}

func (e *InsertValues) fillDefaultValues(row []types.Datum, ignoreErr bool) error {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hasValue columns is filled by default values too?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fillDefaultValues invoked before eval actual insert expression. It fill a empty row with default value, then this row role as a eval context.

var defaultValueCols []*table.Column
for i, c := range e.Table.Cols() {
var err error
if c.IsGenerated() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add a comment for these checks.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will these check be influenced by sql_mode?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It use zero value for not null column even in strict mode. The column filled by fillDefaultValues will not mark in hasValue, so checkRowData will check nullability and insert correct value for column such as auto_increment based on sql_mode.

continue
} else if mysql.HasAutoIncrementFlag(c.Flag) {
row[i] = table.GetZeroValue(c.ToInfo())
} else {
row[i], err = table.GetColDefaultValue(e.ctx, c.ToInfo())
if table.IsNoDefault(err) && mysql.HasNotNullFlag(c.Flag) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If fail to go into this branch, we should handle this error and return?

row[i] = table.GetZeroValue(c.ToInfo())
} else if err != nil {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use filterErr(err) here?

return errors.Trace(err)
}
}
defaultValueCols = append(defaultValueCols, c)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think defaultValueCols is not necessary, just use e.Table.Cols() in table.CastValues(e.ctx, row, e.Table.Cols(), ignoreErr) ?

}
if err := table.CastValues(e.ctx, row, defaultValueCols, ignoreErr); err != nil {
return errors.Trace(err)
}

return nil
}

func (e *InsertValues) getRowsSelect(cols []*table.Column, ignoreErr bool) ([][]types.Datum, error) {
Expand Down Expand Up @@ -929,6 +970,11 @@ func (e *InsertValues) fillRowData(cols []*table.Column, vals []types.Datum, ign
row[offset] = v
hasValue[offset] = true
}

return e.checkRowData(cols, len(vals), hasValue, row, ignoreErr)
}

func (e *InsertValues) checkRowData(cols []*table.Column, valLen int, hasValue []bool, row []types.Datum, ignoreErr bool) ([]types.Datum, error) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this func name is not that explicit,
would fillGenColData be better?

err := e.initDefaultValues(row, hasValue, ignoreErr)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do initDefaultValues again?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fillDefaultValues just ignore sql_mode and use zero value or null for columns, the result row just a eval context so set a = 1, b = a can eval correctly. But the actual value for no value column is influenced by sql_mode and other conditions, so I simply ask initDefaultValues to fill the correct value for no explicit value columns.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we do these check in fillDefaultValues?

if err != nil {
return nil, errors.Trace(err)
Expand All @@ -939,7 +985,7 @@ func (e *InsertValues) fillRowData(cols []*table.Column, vals []types.Datum, ign
if err = e.filterErr(err, ignoreErr); err != nil {
return nil, errors.Trace(err)
}
offset := cols[len(vals)+i].Offset
offset := cols[valLen+i].Offset
row[offset] = val
}
if err = table.CastValues(e.ctx, row, cols, ignoreErr); err != nil {
Expand Down Expand Up @@ -980,6 +1026,9 @@ func (e *InsertValues) initDefaultValues(row []types.Datum, hasValue []bool, ign
// Just leave generated column as null. It will be calculated later
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we add a comment for initDefaultValues
to distinguish it with fillDefaultValues easily.

// but before we check whether the column can be null or not.
needDefaultValue = false
if !hasValue[i] {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

any test case for this check?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test L1279-L1281:

create table t(a int auto_increment key, b int);
set SQL_MODE=NO_AUTO_VALUE_ON_ZERO;
insert into t (b) value (a+1);

If we don't reset a to NULL, a will be 0.

row[i].SetNull()
}
}
if needDefaultValue {
var err error
Expand Down
103 changes: 103 additions & 0 deletions executor/write_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1171,3 +1171,106 @@ func (s *testSuite) TestIssue4067(c *C) {
tk.MustExec("delete from t1 where id in (select id from t2)")
tk.MustQuery("select * from t1").Check(nil)
}

func (s *testSuite) TestInsertCalculatedValue(c *C) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add test cases for more generated columns.

tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test")

tk.MustExec("drop table if exists t")
tk.MustExec("create table t(a int, b int)")
tk.MustExec("insert into t set a=1, b=a+1")
tk.MustQuery("select a, b from t").Check(testkit.Rows("1 2"))

tk.MustExec("drop table if exists t")
tk.MustExec("create table t(a int default 100, b int)")
tk.MustExec("insert into t set b=a+1, a=1")
tk.MustQuery("select a, b from t").Check(testkit.Rows("1 101"))
tk.MustExec("insert into t (b) value (a)")
tk.MustQuery("select * from t where b = 100").Check(testkit.Rows("100 100"))
tk.MustExec("insert into t set a=2, b=a+1")
tk.MustQuery("select * from t where a = 2").Check(testkit.Rows("2 3"))

tk.MustExec("drop table if exists t")
tk.MustExec("create table t (c int)")
tk.MustExec("insert into test.t set test.t.c = '1'")
tk.MustQuery("select * from t").Check(testkit.Rows("1"))

tk.MustExec("drop table if exists t")
tk.MustExec("create table t(a int default 1)")
tk.MustExec("insert into t values (a)")
tk.MustQuery("select * from t").Check(testkit.Rows("1"))

tk.MustExec("drop table if exists t")
tk.MustExec("create table t (a int, b int, c int, d int)")
tk.MustExec("insert into t value (1, 2, a+1, b+1)")
tk.MustQuery("select * from t").Check(testkit.Rows("1 2 2 3"))

tk.MustExec("drop table if exists t")
tk.MustExec("create table t (a int not null)")
tk.MustExec("insert into t values (a+2)")
tk.MustExec("insert into t values (a)")
tk.MustQuery("select * from t order by a").Check(testkit.Rows("0", "2"))

tk.MustExec("drop table if exists t")
tk.MustExec("create table t (a bigint not null, b bigint not null)")
tk.MustExec("insert into t value(b + 1, a)")
tk.MustExec("insert into t set a = b + a, b = a + 1")
tk.MustExec("insert into t value(1000, a)")
tk.MustQuery("select * from t order by a").Check(testkit.Rows("0 1", "1 1", "1000 1000"))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add the case issued by winoros in #4482 , create table without null flags.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add test cases like:

mysql> create table t(a int default 1);
Query OK, 0 rows affected (0.01 sec)

mysql> insert into t values (a);
Query OK, 1 row affected (0.00 sec)

mysql> select * from t;
+------+
| a    |
+------+
|    1 |
+------+
1 row in set (0.00 sec)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add the conner cases like JSON, enum, key with auto increment and so on.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have added these test cases at L1221-L1275.


tk.MustExec("drop table if exists t")
tk.MustExec("create table t(a int)")
tk.MustExec("insert into t values(a)")
tk.MustQuery("select * from t").Check(testkit.Rows("<nil>"))

tk.MustExec("drop table if exists t")
tk.MustExec("create table t(a enum('a', 'b'))")
tk.MustExec("insert into t values(a)")
tk.MustQuery("select * from t").Check(testkit.Rows("<nil>"))
tk.MustExec("drop table if exists t")
tk.MustExec("create table t(a enum('a', 'b') default 'a')")
tk.MustExec("insert into t values(a)")
tk.MustExec("insert into t values(a+1)")
tk.MustQuery("select * from t order by a").Check(testkit.Rows("a", "b"))

tk.MustExec("drop table if exists t")
tk.MustExec("create table t(a blob)")
tk.MustExec("insert into t values(a)")
tk.MustQuery("select * from t").Check(testkit.Rows("<nil>"))

tk.MustExec("drop table if exists t")
tk.MustExec("create table t(a varchar(20) default 'a')")
tk.MustExec("insert into t values(a)")
tk.MustExec("insert into t values(upper(a))")
tk.MustQuery("select * from t order by a").Check(testkit.Rows("A", "a"))
tk.MustExec("drop table if exists t")
tk.MustExec("create table t(a varchar(20) not null, b varchar(20))")
tk.MustExec("insert into t value (a, b)")
tk.MustQuery("select * from t").Check(testkit.Rows(" <nil>"))

tk.MustExec("drop table if exists t")
tk.MustExec("create table t(a int, b int)")
tk.MustExec("insert into t values(a*b, b*b)")
tk.MustQuery("select * from t").Check(testkit.Rows("<nil> <nil>"))

tk.MustExec("drop table if exists t")
tk.MustExec("create table t (a json not null, b int)")
tk.MustExec("insert into t value (a,a->'$')")
tk.MustQuery("select * from t").Check(testkit.Rows("null 0"))

tk.MustExec("drop table if exists t")
tk.MustExec("create table t(a json, b int, c int as (a->'$.a'))")
tk.MustExec("insert into t (a, b) value (a, a->'$.a'+1)")
tk.MustExec("insert into t (b) value (a->'$.a'+1)")
tk.MustQuery("select * from t").Check(testkit.Rows("<nil> <nil> <nil>", "<nil> <nil> <nil>"))
tk.MustExec(`insert into t (a, b) value ('{"a": 1}', a->'$.a'+1)`)
tk.MustQuery("select * from t where c = 1").Check(testkit.Rows(`{"a":1} 2 1`))

tk.MustExec("drop table if exists t")
tk.MustExec("create table t(a int auto_increment key, b int)")
tk.MustExec("insert into t (b) value (a)")
tk.MustExec("insert into t value (a, a+1)")
tk.MustExec("set SQL_MODE=NO_AUTO_VALUE_ON_ZERO")
tk.MustExec("insert into t (b) value (a+1)")
tk.MustQuery("select * from t order by a").Check(testkit.Rows("1 0", "2 1", "3 1"))
}
18 changes: 15 additions & 3 deletions expression/expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -338,9 +338,14 @@ func EvaluateExprWithNull(ctx context.Context, schema *Schema, expr Expression)
}
}

// TableInfo2Schema converts table info to schema.
// TableInfo2Schema converts table info to schema with empty DBName.
func TableInfo2Schema(tbl *model.TableInfo) *Schema {
cols := ColumnInfos2Columns(tbl.Name, tbl.Columns)
return TableInfo2SchemaWithDBName(model.CIStr{}, tbl)
}

// TableInfo2SchemaWithDBName converts table info to schema.
func TableInfo2SchemaWithDBName(dbName model.CIStr, tbl *model.TableInfo) *Schema {
cols := ColumnInfos2ColumnsWithDBName(dbName, tbl.Name, tbl.Columns)
keys := make([]KeyInfo, 0, len(tbl.Indices)+1)
for _, idx := range tbl.Indices {
if !idx.Unique || idx.State != model.StatePublic {
Expand Down Expand Up @@ -382,15 +387,22 @@ func TableInfo2Schema(tbl *model.TableInfo) *Schema {
return schema
}

// ColumnInfos2Columns converts a slice of ColumnInfo to a slice of Column.
// ColumnInfos2Columns converts a slice of ColumnInfo to a slice of Column with empty DBName.
func ColumnInfos2Columns(tblName model.CIStr, colInfos []*model.ColumnInfo) []*Column {
return ColumnInfos2ColumnsWithDBName(model.CIStr{}, tblName, colInfos)
}

// ColumnInfos2ColumnsWithDBName converts a slice of ColumnInfo to a slice of Column.
func ColumnInfos2ColumnsWithDBName(dbName, tblName model.CIStr, colInfos []*model.ColumnInfo) []*Column {
columns := make([]*Column, 0, len(colInfos))
for i, col := range colInfos {
newCol := &Column{
ColName: col.Name,
TblName: tblName,
DBName: dbName,
RetType: &col.FieldType,
Position: i,
Index: col.Offset,
}
columns = append(columns, newCol)
}
Expand Down
2 changes: 1 addition & 1 deletion plan/expression_rewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -1150,5 +1150,5 @@ func (er *expressionRewriter) toColumn(v *ast.ColumnName) {
return
}
}
er.err = ErrUnknownColumn.GenByArgs(v.Text(), "field list")
er.err = ErrUnknownColumn.GenByArgs(v.String(), "field list")
}
12 changes: 6 additions & 6 deletions plan/planbuilder.go
Original file line number Diff line number Diff line change
Expand Up @@ -693,7 +693,8 @@ func (b *planBuilder) buildInsert(insert *ast.InsertStmt) Plan {
return nil
}
tableInfo := tn.TableInfo
schema := expression.TableInfo2Schema(tableInfo)
// Build Schema with DBName otherwise ColumnRef with DBName cannot match any Column in Schema
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add . in the end of line 696

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add . at the end.

schema := expression.TableInfo2SchemaWithDBName(tn.Schema, tableInfo)
tableInPlan, ok := b.is.TableByID(tableInfo.ID)
if !ok {
b.err = errors.Errorf("Can't get table %s.", tableInfo.Name.O)
Expand Down Expand Up @@ -733,6 +734,9 @@ func (b *planBuilder) buildInsert(insert *ast.InsertStmt) Plan {
}
}

mockTablePlan := TableDual{}.init(b.allocator, b.ctx)
mockTablePlan.SetSchema(schema)

cols := insertPlan.Table.Cols()
maxValuesItemLength := 0 // the max length of items in VALUES list.
for _, valuesItem := range insert.Lists {
Expand All @@ -752,7 +756,7 @@ func (b *planBuilder) buildInsert(insert *ast.InsertStmt) Plan {
RetType: &val.Type,
}
} else {
expr, _, err = b.rewrite(valueItem, nil, nil, true)
expr, _, err = b.rewrite(valueItem, mockTablePlan, nil, true)
}
if err != nil {
b.err = errors.Trace(err)
Expand Down Expand Up @@ -784,8 +788,6 @@ func (b *planBuilder) buildInsert(insert *ast.InsertStmt) Plan {
}
}

mockTablePlan := TableDual{}.init(b.allocator, b.ctx)
mockTablePlan.SetSchema(schema)
for _, assign := range insert.Setlist {
col, err := schema.FindColumn(assign.Column)
if err != nil {
Expand All @@ -801,8 +803,6 @@ func (b *planBuilder) buildInsert(insert *ast.InsertStmt) Plan {
b.err = ErrBadGeneratedColumn.GenByArgs(assign.Column.Name.O, tableInfo.Name.O)
return nil
}
// Here we keep different behaviours with MySQL. MySQL allow set a = b, b = a and the result is NULL, NULL.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

keep the comment

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But after merge this PR, result of set a = b, b = a is NULL, NULL.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok. got it.

// It's unreasonable.
expr, _, err := b.rewrite(assign.Expr, mockTablePlan, nil, true)
if err != nil {
b.err = errors.Trace(err)
Expand Down
8 changes: 8 additions & 0 deletions table/column.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/util/hack"
"github.com/pingcap/tidb/util/types"
"github.com/pingcap/tidb/util/types/json"
)

// Column provides meta data describing a table column.
Expand Down Expand Up @@ -344,6 +345,11 @@ func getColDefaultValueFromNil(ctx context.Context, col *model.ColumnInfo) (type
return types.Datum{}, errNoDefaultValue.Gen("Field '%s' doesn't have a default value", col.Name)
}

// IsNoDefault check if err is equal to errNoDefaultValue.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like the IsNoDefault is not necessary?

table.IsNoDefault and errNoDefaultValue.Equal(err) is the same. you can just set errNoDefaultValue to ErrNoDefaultValue

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

s/ check/ checks

func IsNoDefault(err error) bool {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need to define this func,
make errNoDefaultValue to be exported.

return errNoDefaultValue.Equal(err)
}

// GetZeroValue gets zero value for given column type.
func GetZeroValue(col *model.ColumnInfo) types.Datum {
var d types.Datum
Expand Down Expand Up @@ -378,6 +384,8 @@ func GetZeroValue(col *model.ColumnInfo) types.Datum {
d.SetMysqlSet(types.Set{})
case mysql.TypeEnum:
d.SetMysqlEnum(types.Enum{})
case mysql.TypeJSON:
d.SetMysqlJSON(json.CreateJSON(nil))
}
return d
}